[7.x][ML] Explicitly require a OriginSettingClient in ML results iterators (#50981)

In classes where the client is used directly rather than through a call to 
executeAsyncWithOrigin explicitly require the client to be OriginSettingClient 
rather than using the Client interface. 

Also remove calls to deprecated ClientHelper.clientWithOrigin() method.
This commit is contained in:
David Kyle 2020-01-14 17:14:39 +00:00 committed by GitHub
parent a5a8b60d78
commit 7f309a18f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 228 additions and 180 deletions

View File

@ -12,6 +12,7 @@ import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
@ -46,7 +47,7 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
private final ThreadPool threadPool;
private final String executor;
private final Client client;
private final OriginSettingClient client;
private final ClusterService clusterService;
private final Clock clock;
@ -62,7 +63,7 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
super(DeleteExpiredDataAction.NAME, transportService, actionFilters, DeleteExpiredDataAction.Request::new, executor);
this.threadPool = threadPool;
this.executor = executor;
this.client = ClientHelper.clientWithOrigin(client, ClientHelper.ML_ORIGIN);
this.client = new OriginSettingClient(client, ClientHelper.ML_ORIGIN);
this.clusterService = clusterService;
this.clock = clock;
}

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -22,7 +22,7 @@ import java.io.InputStream;
class BatchedBucketsIterator extends BatchedResultsIterator<Bucket> {
BatchedBucketsIterator(Client client, String jobId) {
BatchedBucketsIterator(OriginSettingClient client, String jobId) {
super(client, jobId, Bucket.RESULT_TYPE_VALUE);
}

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -21,7 +21,7 @@ import java.io.IOException;
import java.io.InputStream;
class BatchedInfluencersIterator extends BatchedResultsIterator<Influencer> {
BatchedInfluencersIterator(Client client, String jobId) {
BatchedInfluencersIterator(OriginSettingClient client, String jobId) {
super(client, jobId, Influencer.RESULT_TYPE_VALUE);
}

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
@ -23,7 +23,7 @@ import java.io.InputStream;
public class BatchedJobsIterator extends BatchedDocumentsIterator<Job.Builder> {
public BatchedJobsIterator(Client client, String index) {
public BatchedJobsIterator(OriginSettingClient client, String index) {
super(client, index);
}

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -22,7 +22,7 @@ import java.io.InputStream;
class BatchedRecordsIterator extends BatchedResultsIterator<AnomalyRecord> {
BatchedRecordsIterator(Client client, String jobId) {
BatchedRecordsIterator(OriginSettingClient client, String jobId) {
super(client, jobId, AnomalyRecord.RESULT_TYPE_VALUE);
}

View File

@ -5,7 +5,7 @@
*/
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@ -16,7 +16,7 @@ public abstract class BatchedResultsIterator<T> extends BatchedDocumentsIterator
private final ResultsFilterBuilder filterBuilder;
public BatchedResultsIterator(Client client, String jobId, String resultType) {
public BatchedResultsIterator(OriginSettingClient client, String jobId, String resultType) {
super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId));
this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType));
}

View File

@ -5,7 +5,7 @@
*/
package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
@ -16,7 +16,7 @@ import org.elasticsearch.xpack.ml.utils.persistence.BatchedDocumentsIterator;
*/
public class BatchedStateDocIdsIterator extends BatchedDocumentsIterator<String> {
public BatchedStateDocIdsIterator(Client client, String index) {
public BatchedStateDocIdsIterator(OriginSettingClient client, String index) {
super(client, index);
}

View File

@ -37,6 +37,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlock;
import org.elasticsearch.cluster.block.ClusterBlockException;
@ -132,7 +133,6 @@ import java.util.stream.Collectors;
import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.clientWithOrigin;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class JobResultsProvider {
@ -720,7 +720,7 @@ public class JobResultsProvider {
* @return a bucket {@link BatchedResultsIterator}
*/
public BatchedResultsIterator<Bucket> newBatchedBucketsIterator(String jobId) {
return new BatchedBucketsIterator(clientWithOrigin(client, ML_ORIGIN), jobId);
return new BatchedBucketsIterator(new OriginSettingClient(client, ML_ORIGIN), jobId);
}
/**
@ -732,7 +732,7 @@ public class JobResultsProvider {
* @return a record {@link BatchedResultsIterator}
*/
public BatchedResultsIterator<AnomalyRecord> newBatchedRecordsIterator(String jobId) {
return new BatchedRecordsIterator(clientWithOrigin(client, ML_ORIGIN), jobId);
return new BatchedRecordsIterator(new OriginSettingClient(client, ML_ORIGIN), jobId);
}
/**
@ -929,7 +929,7 @@ public class JobResultsProvider {
* @return an influencer {@link BatchedResultsIterator}
*/
public BatchedResultsIterator<Influencer> newBatchedInfluencersIterator(String jobId) {
return new BatchedInfluencersIterator(clientWithOrigin(client, ML_ORIGIN), jobId);
return new BatchedInfluencersIterator(new OriginSettingClient(client, ML_ORIGIN), jobId);
}
/**

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.job.retention;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@ -34,9 +34,9 @@ import java.util.stream.Collectors;
*/
abstract class AbstractExpiredJobDataRemover implements MlDataRemover {
private final Client client;
private final OriginSettingClient client;
AbstractExpiredJobDataRemover(Client client) {
AbstractExpiredJobDataRemover(OriginSettingClient client) {
this.client = client;
}

View File

@ -13,7 +13,7 @@ import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
@ -62,11 +62,11 @@ public class ExpiredForecastsRemover implements MlDataRemover {
private static final int MAX_FORECASTS = 10000;
private static final String RESULTS_INDEX_PATTERN = AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*";
private final Client client;
private final OriginSettingClient client;
private final ThreadPool threadPool;
private final long cutoffEpochMs;
public ExpiredForecastsRemover(Client client, ThreadPool threadPool) {
public ExpiredForecastsRemover(OriginSettingClient client, ThreadPool threadPool) {
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
this.cutoffEpochMs = Instant.now(Clock.systemDefaultZone()).toEpochMilli();

View File

@ -14,7 +14,7 @@ import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
@ -55,10 +55,10 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
*/
private static final int MODEL_SNAPSHOT_SEARCH_SIZE = 10000;
private final Client client;
private final OriginSettingClient client;
private final ThreadPool threadPool;
public ExpiredModelSnapshotsRemover(Client client, ThreadPool threadPool) {
public ExpiredModelSnapshotsRemover(OriginSettingClient client, ThreadPool threadPool) {
super(client);
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);

View File

@ -9,7 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
@ -46,10 +46,10 @@ public class ExpiredResultsRemover extends AbstractExpiredJobDataRemover {
private static final Logger LOGGER = LogManager.getLogger(ExpiredResultsRemover.class);
private final Client client;
private final OriginSettingClient client;
private final AnomalyDetectionAuditor auditor;
public ExpiredResultsRemover(Client client, AnomalyDetectionAuditor auditor) {
public ExpiredResultsRemover(OriginSettingClient client, AnomalyDetectionAuditor auditor) {
super(client);
this.client = Objects.requireNonNull(client);
this.auditor = Objects.requireNonNull(auditor);

View File

@ -9,7 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.QueryBuilders;
@ -49,10 +49,10 @@ public class UnusedStateRemover implements MlDataRemover {
private static final Logger LOGGER = LogManager.getLogger(UnusedStateRemover.class);
private final Client client;
private final OriginSettingClient client;
private final ClusterService clusterService;
public UnusedStateRemover(Client client, ClusterService clusterService) {
public UnusedStateRemover(OriginSettingClient client, ClusterService clusterService) {
this.client = Objects.requireNonNull(client);
this.clusterService = Objects.requireNonNull(clusterService);
}

View File

@ -10,7 +10,7 @@ import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -34,14 +34,14 @@ public abstract class BatchedDocumentsIterator<T> {
private static final String CONTEXT_ALIVE_DURATION = "5m";
private static final int BATCH_SIZE = 10000;
private final Client client;
private final OriginSettingClient client;
private final String index;
private volatile long count;
private volatile long totalHits;
private volatile String scrollId;
private volatile boolean isScrollInitialised;
protected BatchedDocumentsIterator(Client client, String index) {
protected BatchedDocumentsIterator(OriginSettingClient client, String index) {
this.client = Objects.requireNonNull(client);
this.index = Objects.requireNonNull(index);
this.totalHits = 0;

View File

@ -5,7 +5,7 @@
*/
package org.elasticsearch.xpack.ml.utils.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
@ -18,7 +18,7 @@ public class DocIdBatchedDocumentIterator extends BatchedDocumentsIterator<Strin
private final QueryBuilder query;
public DocIdBatchedDocumentIterator(Client client, String index, QueryBuilder query) {
public DocIdBatchedDocumentIterator(OriginSettingClient client, String index, QueryBuilder query) {
super(client, index);
this.query = Objects.requireNonNull(query);
}

View File

@ -8,7 +8,9 @@ package org.elasticsearch.xpack.ml.job.persistence;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.results.Result;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import java.util.Deque;
import java.util.List;
@ -25,7 +27,7 @@ public class MockBatchedDocumentsIterator<T> extends BatchedResultsIterator<T> {
private Boolean requireIncludeInterim;
public MockBatchedDocumentsIterator(List<Deque<Result<T>>> batches, String resultType) {
super(mock(Client.class), "foo", resultType);
super(MockOriginSettingClient.mockOriginSettingClient(mock(Client.class), ClientHelper.ML_ORIGIN), "foo", resultType);
this.batches = batches;
index = 0;
wasTimeRangeCalled = false;

View File

@ -59,12 +59,12 @@ public class ScoresUpdaterTests extends ESTestCase {
private Job job;
private ScoresUpdater scoresUpdater;
private Bucket generateBucket(Date timestamp) throws IOException {
private Bucket generateBucket(Date timestamp) {
return new Bucket(JOB_ID, timestamp, DEFAULT_BUCKET_SPAN);
}
@Before
public void setUpMocks() throws IOException {
public void setUpMocks() {
MockitoAnnotations.initMocks(this);
Job.Builder jobBuilder = new Job.Builder(JOB_ID);

View File

@ -6,10 +6,11 @@
package org.elasticsearch.xpack.ml.job.retention;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -17,8 +18,10 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.junit.Before;
import java.io.IOException;
@ -32,6 +35,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -44,7 +48,7 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
private int getRetentionDaysCallCount = 0;
ConcreteExpiredJobDataRemover(Client client) {
ConcreteExpiredJobDataRemover(OriginSettingClient client) {
super(client);
}
@ -61,17 +65,30 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
}
}
private OriginSettingClient originSettingClient;
private Client client;
@Before
public void setUpTests() {
client = mock(Client.class);
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
}
static SearchResponse createSearchResponse(List<? extends ToXContent> toXContents) throws IOException {
return createSearchResponse(toXContents, toXContents.size());
}
@SuppressWarnings("unchecked")
static void givenJobs(Client client, List<Job> jobs) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(response);
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
private static SearchResponse createSearchResponse(List<? extends ToXContent> toXContents, int totalHits) throws IOException {
SearchHit[] hitsArray = new SearchHit[toXContents.size()];
for (int i = 0; i < toXContents.size(); i++) {
@ -88,14 +105,10 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
public void testRemoveGivenNoJobs() throws IOException {
SearchResponse response = createSearchResponse(Collections.emptyList());
@SuppressWarnings("unchecked")
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
mockSearchResponse(response);
TestListener listener = new TestListener();
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client);
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient);
remover.remove(listener, () -> false);
listener.waitToCompletion();
@ -103,6 +116,7 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
assertEquals(0, remover.getRetentionDaysCallCount);
}
@SuppressWarnings("unchecked")
public void testRemoveGivenMultipleBatches() throws IOException {
// This is testing AbstractExpiredJobDataRemover.WrappedBatchedJobsIterator
int totalHits = 7;
@ -126,13 +140,14 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
AtomicInteger searchCount = new AtomicInteger(0);
@SuppressWarnings("unchecked")
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
doAnswer(invocationOnMock -> responses.get(searchCount.getAndIncrement())).when(future).actionGet();
when(client.search(any())).thenReturn(future);
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(responses.get(searchCount.getAndIncrement()));
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
TestListener listener = new TestListener();
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client);
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient);
remover.remove(listener, () -> false);
listener.waitToCompletion();
@ -153,13 +168,10 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
final int timeoutAfter = randomIntBetween(0, totalHits - 1);
AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter);
@SuppressWarnings("unchecked")
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
mockSearchResponse(response);
TestListener listener = new TestListener();
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client);
ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient);
remover.remove(listener, () -> (attemptsLeft.getAndDecrement() <= 0));
listener.waitToCompletion();
@ -167,6 +179,15 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
assertEquals(timeoutAfter, remover.getRetentionDaysCallCount);
}
@SuppressWarnings("unchecked")
private void mockSearchResponse(SearchResponse searchResponse) {
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(searchResponse);
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
static class TestListener implements ActionListener<Boolean> {
boolean success;

View File

@ -5,24 +5,25 @@
*/
package org.elasticsearch.xpack.ml.job.retention;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.junit.After;
import org.junit.Before;
import org.mockito.invocation.InvocationOnMock;
@ -33,21 +34,23 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.xpack.ml.job.retention.AbstractExpiredJobDataRemoverTests.TestListener;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
private Client client;
private OriginSettingClient originSettingClient;
private ThreadPool threadPool;
private List<SearchRequest> capturedSearchRequests;
private List<DeleteModelSnapshotAction.Request> capturedDeleteModelSnapshotRequests;
@ -59,7 +62,10 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
capturedSearchRequests = new ArrayList<>();
capturedDeleteModelSnapshotRequests = new ArrayList<>();
searchResponsesPerCall = new ArrayList<>();
client = mock(Client.class);
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
listener = new TestListener();
// Init thread pool
@ -76,8 +82,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
}
public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException {
givenClientRequestsSucceed();
givenJobs(Arrays.asList(
givenClientRequestsSucceed(Arrays.asList(
JobTests.buildJobBuilder("foo").build(),
JobTests.buildJobBuilder("bar").build()
));
@ -86,25 +91,22 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
listener.waitToCompletion();
assertThat(listener.success, is(true));
verify(client).search(any());
Mockito.verifyNoMoreInteractions(client);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
public void testRemove_GivenJobWithoutActiveSnapshot() throws IOException {
givenClientRequestsSucceed();
givenJobs(Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build()));
givenClientRequestsSucceed(Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build()));
createExpiredModelSnapshotsRemover().remove(listener, () -> false);
listener.waitToCompletion();
assertThat(listener.success, is(true));
verify(client).search(any());
Mockito.verifyNoMoreInteractions(client);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException {
givenClientRequestsSucceed();
givenJobs(Arrays.asList(
givenClientRequestsSucceed(
Arrays.asList(
JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
@ -140,8 +142,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
}
public void testRemove_GivenTimeout() throws IOException {
givenClientRequestsSucceed();
givenJobs(Arrays.asList(
givenClientRequestsSucceed(
Arrays.asList(
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
));
@ -162,8 +164,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
}
public void testRemove_GivenClientSearchRequestsFail() throws IOException {
givenClientSearchRequestsFail();
givenJobs(Arrays.asList(
givenClientSearchRequestsFail(
Arrays.asList(
JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
@ -188,8 +190,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
}
public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException {
givenClientDeleteModelSnapshotRequestsFail();
givenJobs(Arrays.asList(
givenClientDeleteModelSnapshotRequestsFail(
Arrays.asList(
JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
@ -216,59 +218,47 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-1_1"));
}
@SuppressWarnings("unchecked")
private void givenJobs(List<Job> jobs) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
}
private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() {
return new ExpiredModelSnapshotsRemover(client, threadPool);
return new ExpiredModelSnapshotsRemover(originSettingClient, threadPool);
}
private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId) {
return new ModelSnapshot.Builder(jobId).setSnapshotId(snapshotId).build();
}
// private static SearchResponse createSearchResponse(List<ModelSnapshot> modelSnapshots) throws IOException {
// SearchHit[] hitsArray = new SearchHit[modelSnapshots.size()];
// for (int i = 0; i < modelSnapshots.size(); i++) {
// hitsArray[i] = new SearchHit(randomInt());
// XContentBuilder jsonBuilder = JsonXContent.contentBuilder();
// modelSnapshots.get(i).toXContent(jsonBuilder, ToXContent.EMPTY_PARAMS);
// hitsArray[i].sourceRef(BytesReference.bytes(jsonBuilder));
// }
// SearchHits hits = new SearchHits(hitsArray, new TotalHits(hitsArray.length, TotalHits.Relation.EQUAL_TO), 1.0f);
// SearchResponse searchResponse = mock(SearchResponse.class);
// when(searchResponse.getHits()).thenReturn(hits);
// return searchResponse;
// }
private void givenClientRequestsSucceed() {
givenClientRequests(true, true);
private void givenClientRequestsSucceed(List<Job> jobs) throws IOException {
givenClientRequests(jobs, true, true);
}
private void givenClientSearchRequestsFail() {
givenClientRequests(false, true);
private void givenClientSearchRequestsFail(List<Job> jobs) throws IOException {
givenClientRequests(jobs, false, true);
}
private void givenClientDeleteModelSnapshotRequestsFail() {
givenClientRequests(true, false);
private void givenClientDeleteModelSnapshotRequestsFail(List<Job> jobs) throws IOException {
givenClientRequests(jobs, true, false);
}
@SuppressWarnings("unchecked")
private void givenClientRequests(boolean shouldSearchRequestsSucceed, boolean shouldDeleteSnapshotRequestsSucceed) {
private void givenClientRequests(List<Job> jobs,
boolean shouldSearchRequestsSucceed, boolean shouldDeleteSnapshotRequestsSucceed) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
doAnswer(new Answer<Void>() {
int callCount = 0;
AtomicBoolean isJobQuery = new AtomicBoolean(true);
@Override
public Void answer(InvocationOnMock invocationOnMock) {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
if (isJobQuery.get()) {
listener.onResponse(response);
isJobQuery.set(false);
return null;
}
SearchRequest searchRequest = (SearchRequest) invocationOnMock.getArguments()[1];
capturedSearchRequests.add(searchRequest);
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
if (shouldSearchRequestsSucceed) {
listener.onResponse(searchResponsesPerCall.get(callCount++));
} else {
@ -277,6 +267,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
return null;
}
}).when(client).execute(same(SearchAction.INSTANCE), any(), any());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) {

View File

@ -5,22 +5,19 @@
*/
package org.elasticsearch.xpack.ml.job.retention;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.index.reindex.BulkByScrollResponse;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.junit.Before;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@ -34,6 +31,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
@ -43,6 +41,7 @@ import static org.mockito.Mockito.when;
public class ExpiredResultsRemoverTests extends ESTestCase {
private Client client;
private OriginSettingClient originSettingClient;
private List<DeleteByQueryRequest> capturedDeleteByQueryRequests;
private ActionListener<Boolean> listener;
@ -50,37 +49,26 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void setUpTests() {
capturedDeleteByQueryRequests = new ArrayList<>();
client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]);
ActionListener<BulkByScrollResponse> listener =
(ActionListener<BulkByScrollResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(null);
return null;
}
}).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any());
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
listener = mock(ActionListener.class);
}
public void testRemove_GivenNoJobs() throws IOException {
givenClientRequestsSucceed();
givenJobs(Collections.emptyList());
AbstractExpiredJobDataRemoverTests.givenJobs(client, Collections.emptyList());
createExpiredResultsRemover().remove(listener, () -> false);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
verify(listener).onResponse(true);
verify(client).search(any());
Mockito.verifyNoMoreInteractions(client);
}
public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException {
givenClientRequestsSucceed();
givenJobs(Arrays.asList(
AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("foo").build(),
JobTests.buildJobBuilder("bar").build()
));
@ -88,13 +76,13 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
createExpiredResultsRemover().remove(listener, () -> false);
verify(listener).onResponse(true);
verify(client).search(any());
Mockito.verifyNoMoreInteractions(client);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
}
public void testRemove_GivenJobsWithAndWithoutRetentionPolicy() throws Exception {
givenClientRequestsSucceed();
givenJobs(Arrays.asList(
AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
@ -112,7 +100,8 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
public void testRemove_GivenTimeout() throws Exception {
givenClientRequestsSucceed();
givenJobs(Arrays.asList(
AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
));
@ -128,7 +117,8 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
public void testRemove_GivenClientRequestsFailed() throws IOException {
givenClientRequestsFailed();
givenJobs(Arrays.asList(
AbstractExpiredJobDataRemoverTests.givenJobs(client,
Arrays.asList(
JobTests.buildJobBuilder("none").build(),
JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
@ -154,7 +144,7 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
private void givenClientRequests(boolean shouldSucceed) {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
public Void answer(InvocationOnMock invocationOnMock) {
capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]);
ActionListener<BulkByScrollResponse> listener =
(ActionListener<BulkByScrollResponse>) invocationOnMock.getArguments()[2];
@ -170,16 +160,7 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
}).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any());
}
@SuppressWarnings("unchecked")
private void givenJobs(List<Job> jobs) throws IOException {
SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(response);
when(client.search(any())).thenReturn(future);
}
private ExpiredResultsRemover createExpiredResultsRemover() {
return new ExpiredResultsRemover(client, mock(AnomalyDetectionAuditor.class));
return new ExpiredResultsRemover(originSettingClient, mock(AnomalyDetectionAuditor.class));
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.test;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.mockito.Mockito;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* OriginSettingClient is a final class that cannot be mocked by mockito.
* The solution is to wrap a non-mocked OriginSettingClient around a
* mocked Client. All the mocking should take place on the client parameter.
*/
public class MockOriginSettingClient {
/**
* Create an OriginSettingClient on a mocked client.
*
* @param client The mocked client
* @param origin Whatever
* @return An OriginSettingClient using a mocked client
*/
public static OriginSettingClient mockOriginSettingClient(Client client, String origin) {
if (Mockito.mockingDetails(client).isMock() == false) {
throw new AssertionError("client should be a mock");
}
ThreadContext tc = new ThreadContext(Settings.EMPTY);
ThreadPool tp = mock(ThreadPool.class);
when(tp.getThreadContext()).thenReturn(tc);
when(client.threadPool()).thenReturn(tp);
when(client.settings()).thenReturn(Settings.EMPTY);
return new OriginSettingClient(client, origin);
}
}

View File

@ -6,12 +6,16 @@
package org.elasticsearch.xpack.ml.utils.persistence;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.ClearScrollRequestBuilder;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.ClearScrollAction;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollAction;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@ -19,6 +23,8 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
@ -30,9 +36,13 @@ import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -42,6 +52,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
private static final String SCROLL_ID = "someScrollId";
private Client client;
private OriginSettingClient originSettingClient;
private boolean wasScrollCleared;
private TestIterator testIterator;
@ -52,8 +63,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
@Before
public void setUpMocks() {
client = Mockito.mock(Client.class);
originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
wasScrollCleared = false;
testIterator = new TestIterator(client, INDEX_NAME);
testIterator = new TestIterator(originSettingClient, INDEX_NAME);
givenClearScrollRequest();
}
@ -122,14 +134,14 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
return "{\"foo\":\"" + value + "\"}";
}
@SuppressWarnings("unchecked")
private void givenClearScrollRequest() {
ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class);
when(client.prepareClearScroll()).thenReturn(requestBuilder);
when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
when(requestBuilder.get()).thenAnswer((invocation) -> {
doAnswer(invocationOnMock -> {
ActionListener<ClearScrollResponse> listener = (ActionListener<ClearScrollResponse>) invocationOnMock.getArguments()[2];
wasScrollCleared = true;
listener.onResponse(mock(ClearScrollResponse.class));
return null;
});
}).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any());
}
private void assertSearchRequest() {
@ -157,6 +169,8 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
private long totalHits = 0;
private List<SearchResponse> responses = new ArrayList<>();
private AtomicInteger responseIndex = new AtomicInteger(0);
ScrollResponsesMocker addBatch(String... hits) {
totalHits += hits.length;
batches.add(hits);
@ -174,33 +188,23 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
givenNextResponse(batches.get(i));
}
if (responses.size() > 0) {
ActionFuture<SearchResponse> first = wrapResponse(responses.get(0));
if (responses.size() > 1) {
List<ActionFuture<SearchResponse>> rest = new ArrayList<>();
for (int i = 1; i < responses.size(); ++i) {
rest.add(wrapResponse(responses.get(i)));
}
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(
first, rest.toArray(new ActionFuture[rest.size() - 1]));
} else {
when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(first);
}
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(responses.get(responseIndex.getAndIncrement()));
return null;
}).when(client).execute(eq(SearchScrollAction.INSTANCE), searchScrollRequestCaptor.capture(), any());
}
}
@SuppressWarnings("unchecked")
private void givenInitialResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
ActionFuture<SearchResponse> future = wrapResponse(searchResponse);
when(future.actionGet()).thenReturn(searchResponse);
when(client.search(searchRequestCaptor.capture())).thenReturn(future);
}
@SuppressWarnings("unchecked")
private ActionFuture<SearchResponse> wrapResponse(SearchResponse searchResponse) {
ActionFuture<SearchResponse> future = mock(ActionFuture.class);
when(future.actionGet()).thenReturn(searchResponse);
return future;
doAnswer(invocationOnMock -> {
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(searchResponse);
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), searchRequestCaptor.capture(), any());
}
private void givenNextResponse(String... hits) {
@ -225,7 +229,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
}
private static class TestIterator extends BatchedDocumentsIterator<String> {
TestIterator(Client client, String jobId) {
TestIterator(OriginSettingClient client, String jobId) {
super(client, jobId);
}