[7.x][ML] Explicitly require a OriginSettingClient in ML results iterators (#50981)

In classes where the client is used directly rather than through a call to 
executeAsyncWithOrigin explicitly require the client to be OriginSettingClient 
rather than using the Client interface. 

Also remove calls to deprecated ClientHelper.clientWithOrigin() method.
This commit is contained in:
David Kyle 2020-01-14 17:14:39 +00:00 committed by GitHub
parent a5a8b60d78
commit 7f309a18f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 228 additions and 180 deletions

View File

@ -12,6 +12,7 @@ import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
@ -46,7 +47,7 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final String executor; private final String executor;
private final Client client; private final OriginSettingClient client;
private final ClusterService clusterService; private final ClusterService clusterService;
private final Clock clock; private final Clock clock;
@ -62,7 +63,7 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
super(DeleteExpiredDataAction.NAME, transportService, actionFilters, DeleteExpiredDataAction.Request::new, executor); super(DeleteExpiredDataAction.NAME, transportService, actionFilters, DeleteExpiredDataAction.Request::new, executor);
this.threadPool = threadPool; this.threadPool = threadPool;
this.executor = executor; this.executor = executor;
this.client = ClientHelper.clientWithOrigin(client, ClientHelper.ML_ORIGIN); this.client = new OriginSettingClient(client, ClientHelper.ML_ORIGIN);
this.clusterService = clusterService; this.clusterService = clusterService;
this.clock = clock; this.clock = clock;
} }

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence; package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -22,7 +22,7 @@ import java.io.InputStream;
class BatchedBucketsIterator extends BatchedResultsIterator<Bucket> { class BatchedBucketsIterator extends BatchedResultsIterator<Bucket> {
BatchedBucketsIterator(Client client, String jobId) { BatchedBucketsIterator(OriginSettingClient client, String jobId) {
super(client, jobId, Bucket.RESULT_TYPE_VALUE); super(client, jobId, Bucket.RESULT_TYPE_VALUE);
} }

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence; package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -21,7 +21,7 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
class BatchedInfluencersIterator extends BatchedResultsIterator<Influencer> { class BatchedInfluencersIterator extends BatchedResultsIterator<Influencer> {
BatchedInfluencersIterator(Client client, String jobId) { BatchedInfluencersIterator(OriginSettingClient client, String jobId) {
super(client, jobId, Influencer.RESULT_TYPE_VALUE); super(client, jobId, Influencer.RESULT_TYPE_VALUE);
} }

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence; package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
@ -23,7 +23,7 @@ import java.io.InputStream;
public class BatchedJobsIterator extends BatchedDocumentsIterator<Job.Builder> { public class BatchedJobsIterator extends BatchedDocumentsIterator<Job.Builder> {
public BatchedJobsIterator(Client client, String index) { public BatchedJobsIterator(OriginSettingClient client, String index) {
super(client, index); super(client, index);
} }

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence; package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -22,7 +22,7 @@ import java.io.InputStream;
class BatchedRecordsIterator extends BatchedResultsIterator<AnomalyRecord> { class BatchedRecordsIterator extends BatchedResultsIterator<AnomalyRecord> {
BatchedRecordsIterator(Client client, String jobId) { BatchedRecordsIterator(OriginSettingClient client, String jobId) {
super(client, jobId, AnomalyRecord.RESULT_TYPE_VALUE); super(client, jobId, AnomalyRecord.RESULT_TYPE_VALUE);
} }

View File

@ -5,7 +5,7 @@
*/ */
package org.elasticsearch.xpack.ml.job.persistence; package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@ -16,7 +16,7 @@ public abstract class BatchedResultsIterator<T> extends BatchedDocumentsIterator
private final ResultsFilterBuilder filterBuilder; private final ResultsFilterBuilder filterBuilder;
public BatchedResultsIterator(Client client, String jobId, String resultType) { public BatchedResultsIterator(OriginSettingClient client, String jobId, String resultType) {
super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId)); super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId));
this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType)); this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
} }

View File

@ -5,7 +5,7 @@
*/ */
package org.elasticsearch.xpack.ml.job.persistence; package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
@ -16,7 +16,7 @@ import org.elasticsearch.xpack.ml.utils.persistence.BatchedDocumentsIterator;
*/ */
public class BatchedStateDocIdsIterator extends BatchedDocumentsIterator<String> { public class BatchedStateDocIdsIterator extends BatchedDocumentsIterator<String> {
public BatchedStateDocIdsIterator(Client client, String index) { public BatchedStateDocIdsIterator(OriginSettingClient client, String index) {
super(client, index); super(client, index);
} }

View File

@ -37,6 +37,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlock; import org.elasticsearch.cluster.block.ClusterBlock;
import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockException;
@ -132,7 +133,6 @@ import java.util.stream.Collectors;
import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.clientWithOrigin;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class JobResultsProvider { public class JobResultsProvider {
@ -720,7 +720,7 @@ public class JobResultsProvider {
* @return a bucket {@link BatchedResultsIterator} * @return a bucket {@link BatchedResultsIterator}
*/ */
public BatchedResultsIterator<Bucket> newBatchedBucketsIterator(String jobId) { public BatchedResultsIterator<Bucket> newBatchedBucketsIterator(String jobId) {
return new BatchedBucketsIterator(clientWithOrigin(client, ML_ORIGIN), jobId); return new BatchedBucketsIterator(new OriginSettingClient(client, ML_ORIGIN), jobId);
} }
/** /**
@ -732,7 +732,7 @@ public class JobResultsProvider {
* @return a record {@link BatchedResultsIterator} * @return a record {@link BatchedResultsIterator}
*/ */
public BatchedResultsIterator<AnomalyRecord> newBatchedRecordsIterator(String jobId) { public BatchedResultsIterator<AnomalyRecord> newBatchedRecordsIterator(String jobId) {
return new BatchedRecordsIterator(clientWithOrigin(client, ML_ORIGIN), jobId); return new BatchedRecordsIterator(new OriginSettingClient(client, ML_ORIGIN), jobId);
} }
/** /**
@ -929,7 +929,7 @@ public class JobResultsProvider {
* @return an influencer {@link BatchedResultsIterator} * @return an influencer {@link BatchedResultsIterator}
*/ */
public BatchedResultsIterator<Influencer> newBatchedInfluencersIterator(String jobId) { public BatchedResultsIterator<Influencer> newBatchedInfluencersIterator(String jobId) {
return new BatchedInfluencersIterator(clientWithOrigin(client, ML_ORIGIN), jobId); return new BatchedInfluencersIterator(new OriginSettingClient(client, ML_ORIGIN), jobId);
} }
/** /**

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.retention; package org.elasticsearch.xpack.ml.job.retention;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
@ -34,9 +34,9 @@ import java.util.stream.Collectors;
*/ */
abstract class AbstractExpiredJobDataRemover implements MlDataRemover { abstract class AbstractExpiredJobDataRemover implements MlDataRemover {
private final Client client; private final OriginSettingClient client;
AbstractExpiredJobDataRemover(Client client) { AbstractExpiredJobDataRemover(OriginSettingClient client) {
this.client = client; this.client = client;
} }

View File

@ -13,7 +13,7 @@ import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
@ -62,11 +62,11 @@ public class ExpiredForecastsRemover implements MlDataRemover {
private static final int MAX_FORECASTS = 10000; private static final int MAX_FORECASTS = 10000;
private static final String RESULTS_INDEX_PATTERN = AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*"; private static final String RESULTS_INDEX_PATTERN = AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*";
private final Client client; private final OriginSettingClient client;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final long cutoffEpochMs; private final long cutoffEpochMs;
public ExpiredForecastsRemover(Client client, ThreadPool threadPool) { public ExpiredForecastsRemover(OriginSettingClient client, ThreadPool threadPool) {
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool); this.threadPool = Objects.requireNonNull(threadPool);
this.cutoffEpochMs = Instant.now(Clock.systemDefaultZone()).toEpochMilli(); this.cutoffEpochMs = Instant.now(Clock.systemDefaultZone()).toEpochMilli();

View File

@ -14,7 +14,7 @@ import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
@ -55,10 +55,10 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
*/ */
private static final int MODEL_SNAPSHOT_SEARCH_SIZE = 10000; private static final int MODEL_SNAPSHOT_SEARCH_SIZE = 10000;
private final Client client; private final OriginSettingClient client;
private final ThreadPool threadPool; private final ThreadPool threadPool;
public ExpiredModelSnapshotsRemover(Client client, ThreadPool threadPool) { public ExpiredModelSnapshotsRemover(OriginSettingClient client, ThreadPool threadPool) {
super(client); super(client);
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool); this.threadPool = Objects.requireNonNull(threadPool);

View File

@ -9,7 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest; import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
@ -46,10 +46,10 @@ public class ExpiredResultsRemover extends AbstractExpiredJobDataRemover {
private static final Logger LOGGER = LogManager.getLogger(ExpiredResultsRemover.class); private static final Logger LOGGER = LogManager.getLogger(ExpiredResultsRemover.class);
private final Client client; private final OriginSettingClient client;
private final AnomalyDetectionAuditor auditor; private final AnomalyDetectionAuditor auditor;
public ExpiredResultsRemover(Client client, AnomalyDetectionAuditor auditor) { public ExpiredResultsRemover(OriginSettingClient client, AnomalyDetectionAuditor auditor) {
super(client); super(client);
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.auditor = Objects.requireNonNull(auditor); this.auditor = Objects.requireNonNull(auditor);

View File

@ -9,7 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
@ -49,10 +49,10 @@ public class UnusedStateRemover implements MlDataRemover {
private static final Logger LOGGER = LogManager.getLogger(UnusedStateRemover.class); private static final Logger LOGGER = LogManager.getLogger(UnusedStateRemover.class);
private final Client client; private final OriginSettingClient client;
private final ClusterService clusterService; private final ClusterService clusterService;
public UnusedStateRemover(Client client, ClusterService clusterService) { public UnusedStateRemover(OriginSettingClient client, ClusterService clusterService) {
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.clusterService = Objects.requireNonNull(clusterService); this.clusterService = Objects.requireNonNull(clusterService);
} }

View File

@ -10,7 +10,7 @@ import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -34,14 +34,14 @@ public abstract class BatchedDocumentsIterator<T> {
private static final String CONTEXT_ALIVE_DURATION = "5m"; private static final String CONTEXT_ALIVE_DURATION = "5m";
private static final int BATCH_SIZE = 10000; private static final int BATCH_SIZE = 10000;
private final Client client; private final OriginSettingClient client;
private final String index; private final String index;
private volatile long count; private volatile long count;
private volatile long totalHits; private volatile long totalHits;
private volatile String scrollId; private volatile String scrollId;
private volatile boolean isScrollInitialised; private volatile boolean isScrollInitialised;
protected BatchedDocumentsIterator(Client client, String index) { protected BatchedDocumentsIterator(OriginSettingClient client, String index) {
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.index = Objects.requireNonNull(index); this.index = Objects.requireNonNull(index);
this.totalHits = 0; this.totalHits = 0;

View File

@ -5,7 +5,7 @@
*/ */
package org.elasticsearch.xpack.ml.utils.persistence; package org.elasticsearch.xpack.ml.utils.persistence;
import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
@ -18,7 +18,7 @@ public class DocIdBatchedDocumentIterator extends BatchedDocumentsIterator<Strin
private final QueryBuilder query; private final QueryBuilder query;
public DocIdBatchedDocumentIterator(Client client, String index, QueryBuilder query) { public DocIdBatchedDocumentIterator(OriginSettingClient client, String index, QueryBuilder query) {
super(client, index); super(client, index);
this.query = Objects.requireNonNull(query); this.query = Objects.requireNonNull(query);
} }

View File

@ -8,7 +8,9 @@ package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.results.Result; import org.elasticsearch.xpack.core.ml.job.results.Result;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
@ -25,7 +27,7 @@ public class MockBatchedDocumentsIterator<T> extends BatchedResultsIterator<T> {
private Boolean requireIncludeInterim; private Boolean requireIncludeInterim;
public MockBatchedDocumentsIterator(List<Deque<Result<T>>> batches, String resultType) { public MockBatchedDocumentsIterator(List<Deque<Result<T>>> batches, String resultType) {
super(mock(Client.class), "foo", resultType); super(MockOriginSettingClient.mockOriginSettingClient(mock(Client.class), ClientHelper.ML_ORIGIN), "foo", resultType);
this.batches = batches; this.batches = batches;
index = 0; index = 0;
wasTimeRangeCalled = false; wasTimeRangeCalled = false;

View File

@ -59,12 +59,12 @@ public class ScoresUpdaterTests extends ESTestCase {
private Job job; private Job job;
private ScoresUpdater scoresUpdater; private ScoresUpdater scoresUpdater;
private Bucket generateBucket(Date timestamp) throws IOException { private Bucket generateBucket(Date timestamp) {
return new Bucket(JOB_ID, timestamp, DEFAULT_BUCKET_SPAN); return new Bucket(JOB_ID, timestamp, DEFAULT_BUCKET_SPAN);
} }
@Before @Before
public void setUpMocks() throws IOException { public void setUpMocks() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
Job.Builder jobBuilder = new Job.Builder(JOB_ID); Job.Builder jobBuilder = new Job.Builder(JOB_ID);

View File

@ -6,10 +6,11 @@
package org.elasticsearch.xpack.ml.job.retention; package org.elasticsearch.xpack.ml.job.retention;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
@ -17,8 +18,10 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobTests; import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
@ -32,6 +35,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -44,7 +48,7 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
private int getRetentionDaysCallCount = 0; private int getRetentionDaysCallCount = 0;
ConcreteExpiredJobDataRemover(Client client) { ConcreteExpiredJobDataRemover(OriginSettingClient client) {
super(client); super(client);
} }
@ -61,17 +65,30 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
} }
} }
private OriginSettingClient originSettingClient;
private Client client; private Client client;
@Before @Before
public void setUpTests() { public void setUpTests() {
client = mock(Client.class); client = mock(Client.class);
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
} }
static SearchResponse createSearchResponse(List<? extends ToXContent> toXContents) throws IOException { static SearchResponse createSearchResponse(List<? extends ToXContent> toXContents) throws IOException {
return createSearchResponse(toXContents, toXContents.size()); return createSearchResponse(toXContents, toXContents.size());
} }
@SuppressWarnings("unchecked")
static void givenJobs(Client client, List<Job> jobs) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(response);
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
private static SearchResponse createSearchResponse(List<? extends ToXContent> toXContents, int totalHits) throws IOException { private static SearchResponse createSearchResponse(List<? extends ToXContent> toXContents, int totalHits) throws IOException {
SearchHit[] hitsArray = new SearchHit[toXContents.size()]; SearchHit[] hitsArray = new SearchHit[toXContents.size()];
for (int i = 0; i < toXContents.size(); i++) { for (int i = 0; i < toXContents.size(); i++) {
@ -88,14 +105,10 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
public void testRemoveGivenNoJobs() throws IOException { public void testRemoveGivenNoJobs() throws IOException {
SearchResponse response = createSearchResponse(Collections.emptyList()); SearchResponse response = createSearchResponse(Collections.emptyList());
mockSearchResponse(response);
@SuppressWarnings("unchecked")
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
TestListener listener = new TestListener(); TestListener listener = new TestListener();
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client); ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient);
remover.remove(listener, () -> false); remover.remove(listener, () -> false);
listener.waitToCompletion(); listener.waitToCompletion();
@ -103,6 +116,7 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
assertEquals(0, remover.getRetentionDaysCallCount); assertEquals(0, remover.getRetentionDaysCallCount);
} }
@SuppressWarnings("unchecked")
public void testRemoveGivenMultipleBatches() throws IOException { public void testRemoveGivenMultipleBatches() throws IOException {
// This is testing AbstractExpiredJobDataRemover.WrappedBatchedJobsIterator // This is testing AbstractExpiredJobDataRemover.WrappedBatchedJobsIterator
int totalHits = 7; int totalHits = 7;
@ -126,13 +140,14 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
AtomicInteger searchCount = new AtomicInteger(0); AtomicInteger searchCount = new AtomicInteger(0);
@SuppressWarnings("unchecked") doAnswer(invocationOnMock -> {
ActionFuture<SearchResponse> future = mock(ActionFuture.class); ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
doAnswer(invocationOnMock -> responses.get(searchCount.getAndIncrement())).when(future).actionGet(); listener.onResponse(responses.get(searchCount.getAndIncrement()));
when(client.search(any())).thenReturn(future); return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
TestListener listener = new TestListener(); TestListener listener = new TestListener();
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client); ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient);
remover.remove(listener, () -> false); remover.remove(listener, () -> false);
listener.waitToCompletion(); listener.waitToCompletion();
@ -153,13 +168,10 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
final int timeoutAfter = randomIntBetween(0, totalHits - 1); final int timeoutAfter = randomIntBetween(0, totalHits - 1);
AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter); AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter);
@SuppressWarnings("unchecked") mockSearchResponse(response);
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
TestListener listener = new TestListener(); TestListener listener = new TestListener();
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client); ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient);
remover.remove(listener, () -> (attemptsLeft.getAndDecrement() <= 0)); remover.remove(listener, () -> (attemptsLeft.getAndDecrement() <= 0));
listener.waitToCompletion(); listener.waitToCompletion();
@ -167,6 +179,15 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
assertEquals(timeoutAfter, remover.getRetentionDaysCallCount); assertEquals(timeoutAfter, remover.getRetentionDaysCallCount);
} }
@SuppressWarnings("unchecked")
private void mockSearchResponse(SearchResponse searchResponse) {
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(searchResponse);
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
static class TestListener implements ActionListener<Boolean> { static class TestListener implements ActionListener<Boolean> {
boolean success; boolean success;

View File

@ -5,24 +5,25 @@
*/ */
package org.elasticsearch.xpack.ml.job.retention; package org.elasticsearch.xpack.ml.job.retention;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobTests; import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
@ -33,21 +34,23 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.xpack.ml.job.retention.AbstractExpiredJobDataRemoverTests.TestListener; import static org.elasticsearch.xpack.ml.job.retention.AbstractExpiredJobDataRemoverTests.TestListener;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ExpiredModelSnapshotsRemoverTests extends ESTestCase { public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
private Client client; private Client client;
private OriginSettingClient originSettingClient;
private ThreadPool threadPool; private ThreadPool threadPool;
private List<SearchRequest> capturedSearchRequests; private List<SearchRequest> capturedSearchRequests;
private List<DeleteModelSnapshotAction.Request> capturedDeleteModelSnapshotRequests; private List<DeleteModelSnapshotAction.Request> capturedDeleteModelSnapshotRequests;
@ -59,7 +62,10 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
capturedSearchRequests = new ArrayList<>(); capturedSearchRequests = new ArrayList<>();
capturedDeleteModelSnapshotRequests = new ArrayList<>(); capturedDeleteModelSnapshotRequests = new ArrayList<>();
searchResponsesPerCall = new ArrayList<>(); searchResponsesPerCall = new ArrayList<>();
client = mock(Client.class); client = mock(Client.class);
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
listener = new TestListener(); listener = new TestListener();
// Init thread pool // Init thread pool
@ -76,8 +82,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
} }
public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException { public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException {
givenClientRequestsSucceed(); givenClientRequestsSucceed(Arrays.asList(
givenJobs(Arrays.asList(
JobTests.buildJobBuilder("foo").build(), JobTests.buildJobBuilder("foo").build(),
JobTests.buildJobBuilder("bar").build() JobTests.buildJobBuilder("bar").build()
)); ));
@ -86,25 +91,22 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
listener.waitToCompletion(); listener.waitToCompletion();
assertThat(listener.success, is(true)); assertThat(listener.success, is(true));
verify(client).search(any()); verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
Mockito.verifyNoMoreInteractions(client);
} }
public void testRemove_GivenJobWithoutActiveSnapshot() throws IOException { public void testRemove_GivenJobWithoutActiveSnapshot() throws IOException {
givenClientRequestsSucceed(); givenClientRequestsSucceed(Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build()));
givenJobs(Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build()));
createExpiredModelSnapshotsRemover().remove(listener, () -> false); createExpiredModelSnapshotsRemover().remove(listener, () -> false);
listener.waitToCompletion(); listener.waitToCompletion();
assertThat(listener.success, is(true)); assertThat(listener.success, is(true));
verify(client).search(any()); verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
Mockito.verifyNoMoreInteractions(client);
} }
public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException { public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException {
givenClientRequestsSucceed(); givenClientRequestsSucceed(
givenJobs(Arrays.asList( Arrays.asList(
JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
@ -140,8 +142,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
} }
public void testRemove_GivenTimeout() throws IOException { public void testRemove_GivenTimeout() throws IOException {
givenClientRequestsSucceed(); givenClientRequestsSucceed(
givenJobs(Arrays.asList( Arrays.asList(
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
)); ));
@ -162,8 +164,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
} }
public void testRemove_GivenClientSearchRequestsFail() throws IOException { public void testRemove_GivenClientSearchRequestsFail() throws IOException {
givenClientSearchRequestsFail(); givenClientSearchRequestsFail(
givenJobs(Arrays.asList( Arrays.asList(
JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
@ -188,8 +190,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
} }
public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException { public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException {
givenClientDeleteModelSnapshotRequestsFail(); givenClientDeleteModelSnapshotRequestsFail(
givenJobs(Arrays.asList( Arrays.asList(
JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
@ -216,59 +218,47 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-1_1")); assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-1_1"));
} }
@SuppressWarnings("unchecked")
private void givenJobs(List<Job> jobs) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
}
private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() { private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() {
return new ExpiredModelSnapshotsRemover(client, threadPool); return new ExpiredModelSnapshotsRemover(originSettingClient, threadPool);
} }
private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId) { private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId) {
return new ModelSnapshot.Builder(jobId).setSnapshotId(snapshotId).build(); return new ModelSnapshot.Builder(jobId).setSnapshotId(snapshotId).build();
} }
// private static SearchResponse createSearchResponse(List<ModelSnapshot> modelSnapshots) throws IOException { private void givenClientRequestsSucceed(List<Job> jobs) throws IOException {
// SearchHit[] hitsArray = new SearchHit[modelSnapshots.size()]; givenClientRequests(jobs, true, true);
// for (int i = 0; i < modelSnapshots.size(); i++) {
// hitsArray[i] = new SearchHit(randomInt());
// XContentBuilder jsonBuilder = JsonXContent.contentBuilder();
// modelSnapshots.get(i).toXContent(jsonBuilder, ToXContent.EMPTY_PARAMS);
// hitsArray[i].sourceRef(BytesReference.bytes(jsonBuilder));
// }
// SearchHits hits = new SearchHits(hitsArray, new TotalHits(hitsArray.length, TotalHits.Relation.EQUAL_TO), 1.0f);
// SearchResponse searchResponse = mock(SearchResponse.class);
// when(searchResponse.getHits()).thenReturn(hits);
// return searchResponse;
// }
private void givenClientRequestsSucceed() {
givenClientRequests(true, true);
} }
private void givenClientSearchRequestsFail() { private void givenClientSearchRequestsFail(List<Job> jobs) throws IOException {
givenClientRequests(false, true); givenClientRequests(jobs, false, true);
} }
private void givenClientDeleteModelSnapshotRequestsFail() { private void givenClientDeleteModelSnapshotRequestsFail(List<Job> jobs) throws IOException {
givenClientRequests(true, false); givenClientRequests(jobs, true, false);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void givenClientRequests(boolean shouldSearchRequestsSucceed, boolean shouldDeleteSnapshotRequestsSucceed) { private void givenClientRequests(List<Job> jobs,
boolean shouldSearchRequestsSucceed, boolean shouldDeleteSnapshotRequestsSucceed) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
int callCount = 0; int callCount = 0;
AtomicBoolean isJobQuery = new AtomicBoolean(true);
@Override @Override
public Void answer(InvocationOnMock invocationOnMock) { public Void answer(InvocationOnMock invocationOnMock) {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
if (isJobQuery.get()) {
listener.onResponse(response);
isJobQuery.set(false);
return null;
}
SearchRequest searchRequest = (SearchRequest) invocationOnMock.getArguments()[1]; SearchRequest searchRequest = (SearchRequest) invocationOnMock.getArguments()[1];
capturedSearchRequests.add(searchRequest); capturedSearchRequests.add(searchRequest);
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
if (shouldSearchRequestsSucceed) { if (shouldSearchRequestsSucceed) {
listener.onResponse(searchResponsesPerCall.get(callCount++)); listener.onResponse(searchResponsesPerCall.get(callCount++));
} else { } else {
@ -277,6 +267,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
return null; return null;
} }
}).when(client).execute(same(SearchAction.INSTANCE), any(), any()); }).when(client).execute(same(SearchAction.INSTANCE), any(), any());
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override @Override
public Void answer(InvocationOnMock invocationOnMock) { public Void answer(InvocationOnMock invocationOnMock) {

View File

@ -5,22 +5,19 @@
*/ */
package org.elasticsearch.xpack.ml.job.retention; package org.elasticsearch.xpack.ml.job.retention;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.reindex.BulkByScrollResponse; import org.elasticsearch.index.reindex.BulkByScrollResponse;
import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobTests; import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.junit.Before; import org.junit.Before;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -34,6 +31,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -43,6 +41,7 @@ import static org.mockito.Mockito.when;
public class ExpiredResultsRemoverTests extends ESTestCase { public class ExpiredResultsRemoverTests extends ESTestCase {
private Client client; private Client client;
private OriginSettingClient originSettingClient;
private List<DeleteByQueryRequest> capturedDeleteByQueryRequests; private List<DeleteByQueryRequest> capturedDeleteByQueryRequests;
private ActionListener<Boolean> listener; private ActionListener<Boolean> listener;
@ -50,37 +49,26 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void setUpTests() { public void setUpTests() {
capturedDeleteByQueryRequests = new ArrayList<>(); capturedDeleteByQueryRequests = new ArrayList<>();
client = mock(Client.class); client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class); originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]);
ActionListener<BulkByScrollResponse> listener =
(ActionListener<BulkByScrollResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(null);
return null;
}
}).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any());
listener = mock(ActionListener.class); listener = mock(ActionListener.class);
} }
public void testRemove_GivenNoJobs() throws IOException { public void testRemove_GivenNoJobs() throws IOException {
givenClientRequestsSucceed(); givenClientRequestsSucceed();
givenJobs(Collections.emptyList()); AbstractExpiredJobDataRemoverTests.givenJobs(client, Collections.emptyList());
createExpiredResultsRemover().remove(listener, () -> false); createExpiredResultsRemover().remove(listener, () -> false);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
verify(listener).onResponse(true); verify(listener).onResponse(true);
verify(client).search(any());
Mockito.verifyNoMoreInteractions(client);
} }
public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException { public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException {
givenClientRequestsSucceed(); givenClientRequestsSucceed();
givenJobs(Arrays.asList( AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("foo").build(), JobTests.buildJobBuilder("foo").build(),
JobTests.buildJobBuilder("bar").build() JobTests.buildJobBuilder("bar").build()
)); ));
@ -88,13 +76,13 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
createExpiredResultsRemover().remove(listener, () -> false); createExpiredResultsRemover().remove(listener, () -> false);
verify(listener).onResponse(true); verify(listener).onResponse(true);
verify(client).search(any()); verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
Mockito.verifyNoMoreInteractions(client);
} }
public void testRemove_GivenJobsWithAndWithoutRetentionPolicy() throws Exception { public void testRemove_GivenJobsWithAndWithoutRetentionPolicy() throws Exception {
givenClientRequestsSucceed(); givenClientRequestsSucceed();
givenJobs(Arrays.asList( AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(), JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build() JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
@ -112,7 +100,8 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
public void testRemove_GivenTimeout() throws Exception { public void testRemove_GivenTimeout() throws Exception {
givenClientRequestsSucceed(); givenClientRequestsSucceed();
givenJobs(Arrays.asList( AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(), JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build() JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
)); ));
@ -128,7 +117,8 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
public void testRemove_GivenClientRequestsFailed() throws IOException { public void testRemove_GivenClientRequestsFailed() throws IOException {
givenClientRequestsFailed(); givenClientRequestsFailed();
givenJobs(Arrays.asList( AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(), JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build() JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
@ -154,7 +144,7 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
private void givenClientRequests(boolean shouldSucceed) { private void givenClientRequests(boolean shouldSucceed) {
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override @Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable { public Void answer(InvocationOnMock invocationOnMock) {
capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]); capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]);
ActionListener<BulkByScrollResponse> listener = ActionListener<BulkByScrollResponse> listener =
(ActionListener<BulkByScrollResponse>) invocationOnMock.getArguments()[2]; (ActionListener<BulkByScrollResponse>) invocationOnMock.getArguments()[2];
@ -170,16 +160,7 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
}).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any()); }).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any());
} }
@SuppressWarnings("unchecked")
private void givenJobs(List<Job> jobs) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
}
private ExpiredResultsRemover createExpiredResultsRemover() { private ExpiredResultsRemover createExpiredResultsRemover() {
return new ExpiredResultsRemover(client, mock(AnomalyDetectionAuditor.class)); return new ExpiredResultsRemover(originSettingClient, mock(AnomalyDetectionAuditor.class));
} }
} }

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.test;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.mockito.Mockito;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* OriginSettingClient is a final class that cannot be mocked by mockito.
* The solution is to wrap a non-mocked OriginSettingClient around a
* mocked Client. All the mocking should take place on the client parameter.
*/
public class MockOriginSettingClient {
/**
* Create an OriginSettingClient on a mocked client.
*
* @param client The mocked client
* @param origin Whatever
* @return An OriginSettingClient using a mocked client
*/
public static OriginSettingClient mockOriginSettingClient(Client client, String origin) {
if (Mockito.mockingDetails(client).isMock() == false) {
throw new AssertionError("client should be a mock");
}
ThreadContext tc = new ThreadContext(Settings.EMPTY);
ThreadPool tp = mock(ThreadPool.class);
when(tp.getThreadContext()).thenReturn(tc);
when(client.threadPool()).thenReturn(tp);
when(client.settings()).thenReturn(Settings.EMPTY);
return new OriginSettingClient(client, origin);
}
}

View File

@ -6,12 +6,16 @@
package org.elasticsearch.xpack.ml.utils.persistence; package org.elasticsearch.xpack.ml.utils.persistence;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.ClearScrollRequestBuilder; import org.elasticsearch.action.search.ClearScrollAction;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollAction;
import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
@ -19,6 +23,8 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -30,9 +36,13 @@ import java.util.Collections;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -42,6 +52,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
private static final String SCROLL_ID = "someScrollId"; private static final String SCROLL_ID = "someScrollId";
private Client client; private Client client;
private OriginSettingClient originSettingClient;
private boolean wasScrollCleared; private boolean wasScrollCleared;
private TestIterator testIterator; private TestIterator testIterator;
@ -52,8 +63,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
@Before @Before
public void setUpMocks() { public void setUpMocks() {
client = Mockito.mock(Client.class); client = Mockito.mock(Client.class);
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
wasScrollCleared = false; wasScrollCleared = false;
testIterator = new TestIterator(client, INDEX_NAME); testIterator = new TestIterator(originSettingClient, INDEX_NAME);
givenClearScrollRequest(); givenClearScrollRequest();
} }
@ -122,14 +134,14 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
return "{\"foo\":\"" + value + "\"}"; return "{\"foo\":\"" + value + "\"}";
} }
@SuppressWarnings("unchecked")
private void givenClearScrollRequest() { private void givenClearScrollRequest() {
ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class); doAnswer(invocationOnMock -> {
when(client.prepareClearScroll()).thenReturn(requestBuilder); ActionListener<ClearScrollResponse> listener = (ActionListener<ClearScrollResponse>) invocationOnMock.getArguments()[2];
when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
when(requestBuilder.get()).thenAnswer((invocation) -> {
wasScrollCleared = true; wasScrollCleared = true;
listener.onResponse(mock(ClearScrollResponse.class));
return null; return null;
}); }).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any());
} }
private void assertSearchRequest() { private void assertSearchRequest() {
@ -157,6 +169,8 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
private long totalHits = 0; private long totalHits = 0;
private List<SearchResponse> responses = new ArrayList<>(); private List<SearchResponse> responses = new ArrayList<>();
private AtomicInteger responseIndex = new AtomicInteger(0);
ScrollResponsesMocker addBatch(String... hits) { ScrollResponsesMocker addBatch(String... hits) {
totalHits += hits.length; totalHits += hits.length;
batches.add(hits); batches.add(hits);
@ -174,33 +188,23 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
givenNextResponse(batches.get(i)); givenNextResponse(batches.get(i));
} }
if (responses.size() > 0) { if (responses.size() > 0) {
ActionFuture<SearchResponse> first = wrapResponse(responses.get(0)); doAnswer(invocationOnMock -> {
if (responses.size() > 1) { ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
List<ActionFuture<SearchResponse>> rest = new ArrayList<>(); listener.onResponse(responses.get(responseIndex.getAndIncrement()));
for (int i = 1; i < responses.size(); ++i) { return null;
rest.add(wrapResponse(responses.get(i))); }).when(client).execute(eq(SearchScrollAction.INSTANCE), searchScrollRequestCaptor.capture(), any());
} }
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(
first, rest.toArray(new ActionFuture[rest.size() - 1]));
} else {
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(first);
}
}
}
private void givenInitialResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
ActionFuture<SearchResponse> future = wrapResponse(searchResponse);
when(future.actionGet()).thenReturn(searchResponse);
when(client.search(searchRequestCaptor.capture())).thenReturn(future);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private ActionFuture<SearchResponse> wrapResponse(SearchResponse searchResponse) { private void givenInitialResponse(String... hits) {
ActionFuture<SearchResponse> future = mock(ActionFuture.class); SearchResponse searchResponse = createSearchResponseWithHits(hits);
when(future.actionGet()).thenReturn(searchResponse);
return future; doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(searchResponse);
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), searchRequestCaptor.capture(), any());
} }
private void givenNextResponse(String... hits) { private void givenNextResponse(String... hits) {
@ -225,7 +229,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
} }
private static class TestIterator extends BatchedDocumentsIterator<String> { private static class TestIterator extends BatchedDocumentsIterator<String> {
TestIterator(Client client, String jobId) { TestIterator(OriginSettingClient client, String jobId) {
super(client, jobId); super(client, jobId);
} }