This commit is contained in:
parent
922a70ce32
commit
665f0d81aa
|
@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
|
@ -62,6 +63,11 @@ public class StartDataFrameTransformTaskAction extends Action<StartDataFrameTran
|
|||
out.writeString(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean match(Task task) {
|
||||
return task.getDescription().equals(DataFrameField.PERSISTENT_TASK_DESCRIPTION_PREFIX + id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActionRequestValidationException validate() {
|
||||
return null;
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.elasticsearch.xpack.core.dataframe.action.StartDataFrameTransformTask
|
|||
import org.elasticsearch.xpack.dataframe.transforms.DataFrameTransformTask;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Internal only transport class to change an allocated persistent task's state to started
|
||||
|
@ -44,27 +43,6 @@ public class TransportStartDataFrameTransformTaskAction extends
|
|||
this.licenseState = licenseState;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processTasks(StartDataFrameTransformTaskAction.Request request, Consumer<DataFrameTransformTask> operation) {
|
||||
DataFrameTransformTask matchingTask = null;
|
||||
|
||||
// todo: re-factor, see rollup TransportTaskHelper
|
||||
for (Task task : taskManager.getTasks().values()) {
|
||||
if (task instanceof DataFrameTransformTask
|
||||
&& ((DataFrameTransformTask) task).getTransformId().equals(request.getId())) {
|
||||
if (matchingTask != null) {
|
||||
throw new IllegalArgumentException("Found more than one matching task for data frame transform [" + request.getId()
|
||||
+ "] when " + "there should only be one.");
|
||||
}
|
||||
matchingTask = (DataFrameTransformTask) task;
|
||||
}
|
||||
}
|
||||
|
||||
if (matchingTask != null) {
|
||||
operation.accept(matchingTask);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, StartDataFrameTransformTaskAction.Request request,
|
||||
ActionListener<StartDataFrameTransformTaskAction.Response> listener) {
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.dataframe.persistence;
|
||||
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.MetaData;
|
||||
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
|
||||
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
|
||||
|
||||
public final class DataFramePersistentTaskUtils {
|
||||
|
||||
private DataFramePersistentTaskUtils() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Check to see if the PersistentTask's cluster state contains the data frame transform(s) we
|
||||
* are interested in
|
||||
*/
|
||||
public static boolean stateHasDataFrameTransforms(String id, ClusterState state) {
|
||||
boolean hasTransforms = false;
|
||||
PersistentTasksCustomMetaData pTasksMeta = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
|
||||
|
||||
if (pTasksMeta != null) {
|
||||
// If the request was for _all transforms, we need to look through the list of
|
||||
// persistent tasks and see if at least one is a data frame task
|
||||
if (id.equals(MetaData.ALL)) {
|
||||
hasTransforms = pTasksMeta.tasks().stream()
|
||||
.anyMatch(persistentTask -> persistentTask.getTaskName().equals(DataFrameField.TASK_NAME));
|
||||
|
||||
} else if (pTasksMeta.getTask(id) != null) {
|
||||
// If we're looking for a single transform, we can just check directly
|
||||
hasTransforms = true;
|
||||
}
|
||||
}
|
||||
return hasTransforms;
|
||||
}
|
||||
}
|
|
@ -1,186 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.dataframe.util;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.ClearScrollRequest;
|
||||
import org.elasticsearch.action.search.ClearScrollResponse;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.search.SearchScrollRequest;
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Provides basic tools around scrolling over documents and gathering the data in some Collection
|
||||
* @param <T> The object type that is being collected
|
||||
* @param <E> The collection that should be used (i.e. Set, Deque, etc.)
|
||||
*/
|
||||
public abstract class BatchedDataIterator<T, E extends Collection<T>> {
|
||||
private static final Logger LOGGER = LogManager.getLogger(BatchedDataIterator.class);
|
||||
|
||||
private static final String CONTEXT_ALIVE_DURATION = "5m";
|
||||
private static final int BATCH_SIZE = 10_000;
|
||||
|
||||
private final Client client;
|
||||
private final String index;
|
||||
private volatile long count;
|
||||
private volatile long totalHits;
|
||||
private volatile String scrollId;
|
||||
private volatile boolean isScrollInitialised;
|
||||
|
||||
protected BatchedDataIterator(Client client, String index) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.index = Objects.requireNonNull(index);
|
||||
this.totalHits = 0;
|
||||
this.count = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns {@code true} if the iteration has more elements.
|
||||
* (In other words, returns {@code true} if {@link #next} would
|
||||
* return an element rather than throwing an exception.)
|
||||
*
|
||||
* @return {@code true} if the iteration has more elements
|
||||
*/
|
||||
public boolean hasNext() {
|
||||
return !isScrollInitialised || count != totalHits;
|
||||
}
|
||||
|
||||
/**
|
||||
* The first time next() is called, the search will be performed and the first
|
||||
* batch will be given to the listener. Any subsequent call will return the following batches.
|
||||
* <p>
|
||||
* Note that in some implementations it is possible that when there are no
|
||||
* results at all. {@link BatchedDataIterator#hasNext()} will return {@code true} the first time it is called but then a call
|
||||
* to this function returns an empty Collection to the listener.
|
||||
*/
|
||||
public void next(ActionListener<E> listener) {
|
||||
if (!hasNext()) {
|
||||
listener.onFailure(new NoSuchElementException());
|
||||
}
|
||||
|
||||
if (!isScrollInitialised) {
|
||||
ActionListener<SearchResponse> wrappedListener = ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
isScrollInitialised = true;
|
||||
totalHits = searchResponse.getHits().getTotalHits().value;
|
||||
scrollId = searchResponse.getScrollId();
|
||||
mapHits(searchResponse, listener);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
initScroll(wrappedListener);
|
||||
} else {
|
||||
ActionListener<SearchResponse> wrappedListener = ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
scrollId = searchResponse.getScrollId();
|
||||
mapHits(searchResponse, listener);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
SearchScrollRequest searchScrollRequest = new SearchScrollRequest(scrollId).scroll(CONTEXT_ALIVE_DURATION);
|
||||
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ClientHelper.DATA_FRAME_ORIGIN,
|
||||
searchScrollRequest,
|
||||
wrappedListener,
|
||||
client::searchScroll);
|
||||
}
|
||||
}
|
||||
|
||||
private void initScroll(ActionListener<SearchResponse> listener) {
|
||||
LOGGER.trace("ES API CALL: search index {}", index);
|
||||
|
||||
SearchRequest searchRequest = new SearchRequest(index);
|
||||
searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
|
||||
searchRequest.scroll(CONTEXT_ALIVE_DURATION);
|
||||
searchRequest.source(new SearchSourceBuilder()
|
||||
.fetchSource(getFetchSourceContext())
|
||||
.size(getBatchSize())
|
||||
.query(getQuery())
|
||||
.trackTotalHits(true)
|
||||
.sort(sortField(), sortOrder()));
|
||||
|
||||
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ClientHelper.DATA_FRAME_ORIGIN,
|
||||
searchRequest,
|
||||
listener,
|
||||
client::search);
|
||||
}
|
||||
|
||||
private void mapHits(SearchResponse searchResponse, ActionListener<E> mappingListener) {
|
||||
E results = getCollection();
|
||||
|
||||
SearchHit[] hits = searchResponse.getHits().getHits();
|
||||
for (SearchHit hit : hits) {
|
||||
T mapped = map(hit);
|
||||
if (mapped != null) {
|
||||
results.add(mapped);
|
||||
}
|
||||
}
|
||||
count += hits.length;
|
||||
|
||||
if (!hasNext() && scrollId != null) {
|
||||
ClearScrollRequest request = client.prepareClearScroll().setScrollIds(Collections.singletonList(scrollId)).request();
|
||||
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ClientHelper.DATA_FRAME_ORIGIN,
|
||||
request,
|
||||
ActionListener.<ClearScrollResponse>wrap(
|
||||
r -> mappingListener.onResponse(results),
|
||||
mappingListener::onFailure
|
||||
),
|
||||
client::clearScroll);
|
||||
} else {
|
||||
mappingListener.onResponse(results);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the query to use for the search
|
||||
* @return the search query
|
||||
*/
|
||||
protected abstract QueryBuilder getQuery();
|
||||
|
||||
/**
|
||||
* Maps the search hit to the document type
|
||||
* @param hit the search hit
|
||||
* @return The mapped document or {@code null} if the mapping failed
|
||||
*/
|
||||
protected abstract T map(SearchHit hit);
|
||||
|
||||
protected abstract E getCollection();
|
||||
|
||||
protected abstract SortOrder sortOrder();
|
||||
|
||||
protected abstract String sortField();
|
||||
|
||||
/**
|
||||
* Should we fetch the source and what fields specifically.
|
||||
*
|
||||
* Defaults to all fields and true.
|
||||
*/
|
||||
protected FetchSourceContext getFetchSourceContext() {
|
||||
return FetchSourceContext.FETCH_SOURCE;
|
||||
}
|
||||
|
||||
protected int getBatchSize() {
|
||||
return BATCH_SIZE;
|
||||
}
|
||||
}
|
|
@ -1,125 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.dataframe.util;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* A utility that allows chained (serial) execution of a number of tasks
|
||||
* in async manner.
|
||||
*/
|
||||
public class TypedChainTaskExecutor<T> {
|
||||
|
||||
public interface ChainTask <T> {
|
||||
void run(ActionListener<T> listener);
|
||||
}
|
||||
|
||||
private final ExecutorService executorService;
|
||||
private final LinkedList<ChainTask<T>> tasks = new LinkedList<>();
|
||||
private final Predicate<Exception> failureShortCircuitPredicate;
|
||||
private final Predicate<T> continuationPredicate;
|
||||
private final List<T> collectedResponses;
|
||||
|
||||
/**
|
||||
* Creates a new TypedChainTaskExecutor.
|
||||
* Each chainedTask is executed in order serially and after each execution the continuationPredicate is tested.
|
||||
*
|
||||
* On failures the failureShortCircuitPredicate is tested.
|
||||
*
|
||||
* @param executorService The service where to execute the tasks
|
||||
* @param continuationPredicate The predicate to test on whether to execute the next task or not.
|
||||
* {@code true} means continue on to the next task.
|
||||
* Must be able to handle null values.
|
||||
* @param failureShortCircuitPredicate The predicate on whether to short circuit execution on a give exception.
|
||||
* {@code true} means that no more tasks should execute and the the listener::onFailure should be
|
||||
* called.
|
||||
*/
|
||||
public TypedChainTaskExecutor(ExecutorService executorService,
|
||||
Predicate<T> continuationPredicate,
|
||||
Predicate<Exception> failureShortCircuitPredicate) {
|
||||
this.executorService = Objects.requireNonNull(executorService);
|
||||
this.continuationPredicate = continuationPredicate;
|
||||
this.failureShortCircuitPredicate = failureShortCircuitPredicate;
|
||||
this.collectedResponses = new ArrayList<>();
|
||||
}
|
||||
|
||||
public synchronized void add(ChainTask<T> task) {
|
||||
tasks.add(task);
|
||||
}
|
||||
|
||||
private synchronized void execute(T previousValue, ActionListener<List<T>> listener) {
|
||||
collectedResponses.add(previousValue);
|
||||
if (continuationPredicate.test(previousValue)) {
|
||||
if (tasks.isEmpty()) {
|
||||
listener.onResponse(Collections.unmodifiableList(new ArrayList<>(collectedResponses)));
|
||||
return;
|
||||
}
|
||||
ChainTask<T> task = tasks.pop();
|
||||
executorService.execute(new AbstractRunnable() {
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
if (failureShortCircuitPredicate.test(e)) {
|
||||
listener.onFailure(e);
|
||||
} else {
|
||||
execute(null, listener);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doRun() {
|
||||
task.run(ActionListener.wrap(value -> execute(value, listener), this::onFailure));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
listener.onResponse(Collections.unmodifiableList(new ArrayList<>(collectedResponses)));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute all the chained tasks serially, notify listener when completed
|
||||
*
|
||||
* @param listener The ActionListener to notify when all executions have been completed,
|
||||
* or when no further tasks should be executed.
|
||||
* The resulting list COULD contain null values depending on if execution is continued
|
||||
* on exceptions or not.
|
||||
*/
|
||||
public synchronized void execute(ActionListener<List<T>> listener) {
|
||||
if (tasks.isEmpty()) {
|
||||
listener.onResponse(Collections.emptyList());
|
||||
return;
|
||||
}
|
||||
collectedResponses.clear();
|
||||
ChainTask<T> task = tasks.pop();
|
||||
executorService.execute(new AbstractRunnable() {
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
if (failureShortCircuitPredicate.test(e)) {
|
||||
listener.onFailure(e);
|
||||
} else {
|
||||
execute(null, listener);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doRun() {
|
||||
task.run(ActionListener.wrap(value -> execute(value, listener), this::onFailure));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public synchronized List<T> getCollectedResponses() {
|
||||
return Collections.unmodifiableList(new ArrayList<>(collectedResponses));
|
||||
}
|
||||
}
|
|
@ -1,329 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.dataframe.util;
|
||||
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.ClearScrollRequest;
|
||||
import org.elasticsearch.action.search.ClearScrollRequestBuilder;
|
||||
import org.elasticsearch.action.search.ClearScrollResponse;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.search.SearchScrollRequest;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.document.DocumentField;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.search.internal.SearchContext;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.junit.Before;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Deque;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class BatchedDataIteratorTests extends ESTestCase {
|
||||
|
||||
private static final String INDEX_NAME = "some_index_name";
|
||||
private static final String SCROLL_ID = "someScrollId";
|
||||
|
||||
private Client client;
|
||||
private boolean wasScrollCleared;
|
||||
|
||||
private TestIterator testIterator;
|
||||
|
||||
private List<SearchRequest> searchRequestCaptor = new ArrayList<>();
|
||||
private List<SearchScrollRequest> searchScrollRequestCaptor = new ArrayList<>();
|
||||
|
||||
@Before
|
||||
public void setUpMocks() {
|
||||
ThreadPool pool = mock(ThreadPool.class);
|
||||
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
|
||||
when(pool.getThreadContext()).thenReturn(threadContext);
|
||||
client = Mockito.mock(Client.class);
|
||||
when(client.threadPool()).thenReturn(pool);
|
||||
wasScrollCleared = false;
|
||||
testIterator = new TestIterator(client, INDEX_NAME);
|
||||
givenClearScrollRequest();
|
||||
searchRequestCaptor.clear();
|
||||
searchScrollRequestCaptor.clear();
|
||||
}
|
||||
|
||||
public void testQueryReturnsNoResults() throws Exception {
|
||||
new ScrollResponsesMocker().finishMock();
|
||||
|
||||
assertTrue(testIterator.hasNext());
|
||||
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
|
||||
testIterator.next(future);
|
||||
assertTrue(future.get().isEmpty());
|
||||
assertFalse(testIterator.hasNext());
|
||||
assertTrue(wasScrollCleared);
|
||||
assertSearchRequest();
|
||||
assertSearchScrollRequests(0);
|
||||
}
|
||||
|
||||
public void testCallingNextWhenHasNextIsFalseThrows() throws Exception {
|
||||
PlainActionFuture<Deque<String>> firstFuture = new PlainActionFuture<>();
|
||||
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
|
||||
testIterator.next(firstFuture);
|
||||
firstFuture.get();
|
||||
assertFalse(testIterator.hasNext());
|
||||
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
|
||||
ExecutionException executionException = ESTestCase.expectThrows(ExecutionException.class, () -> {
|
||||
testIterator.next(future);
|
||||
future.get();
|
||||
});
|
||||
assertNotNull(executionException.getCause());
|
||||
assertTrue(executionException.getCause() instanceof NoSuchElementException);
|
||||
}
|
||||
|
||||
public void testQueryReturnsSingleBatch() throws Exception {
|
||||
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
|
||||
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
|
||||
|
||||
assertTrue(testIterator.hasNext());
|
||||
testIterator.next(future);
|
||||
Deque<String> batch = future.get();
|
||||
assertEquals(3, batch.size());
|
||||
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
|
||||
assertFalse(testIterator.hasNext());
|
||||
assertTrue(wasScrollCleared);
|
||||
|
||||
assertSearchRequest();
|
||||
assertSearchScrollRequests(0);
|
||||
}
|
||||
|
||||
public void testQueryReturnsThreeBatches() throws Exception {
|
||||
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
|
||||
new ScrollResponsesMocker()
|
||||
.addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
|
||||
.addBatch(createJsonDoc("d"), createJsonDoc("e"))
|
||||
.addBatch(createJsonDoc("f"))
|
||||
.finishMock();
|
||||
|
||||
assertTrue(testIterator.hasNext());
|
||||
|
||||
testIterator.next(future);
|
||||
Deque<String> batch = future.get();
|
||||
assertEquals(3, batch.size());
|
||||
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
|
||||
|
||||
future = new PlainActionFuture<>();
|
||||
testIterator.next(future);
|
||||
batch = future.get();
|
||||
assertEquals(2, batch.size());
|
||||
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("d"), createJsonDoc("e"))));
|
||||
|
||||
future = new PlainActionFuture<>();
|
||||
testIterator.next(future);
|
||||
batch = future.get();
|
||||
assertEquals(1, batch.size());
|
||||
assertTrue(batch.contains(createJsonDoc("f")));
|
||||
|
||||
assertFalse(testIterator.hasNext());
|
||||
assertTrue(wasScrollCleared);
|
||||
|
||||
assertSearchRequest();
|
||||
assertSearchScrollRequests(2);
|
||||
}
|
||||
|
||||
private String createJsonDoc(String value) {
|
||||
return "{\"foo\":\"" + value + "\"}";
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void givenClearScrollRequest() {
|
||||
ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class);
|
||||
|
||||
when(client.prepareClearScroll()).thenReturn(requestBuilder);
|
||||
when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
|
||||
ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
|
||||
clearScrollRequest.addScrollId(SCROLL_ID);
|
||||
when(requestBuilder.request()).thenReturn(clearScrollRequest);
|
||||
doAnswer((answer) -> {
|
||||
wasScrollCleared = true;
|
||||
ActionListener<ClearScrollResponse> scrollListener =
|
||||
(ActionListener<ClearScrollResponse>) answer.getArguments()[1];
|
||||
scrollListener.onResponse(new ClearScrollResponse(true,0));
|
||||
return null;
|
||||
}).when(client).clearScroll(any(ClearScrollRequest.class), any(ActionListener.class));
|
||||
}
|
||||
|
||||
private void assertSearchRequest() {
|
||||
List<SearchRequest> searchRequests = searchRequestCaptor;
|
||||
assertThat(searchRequests.size(), equalTo(1));
|
||||
SearchRequest searchRequest = searchRequests.get(0);
|
||||
assertThat(searchRequest.indices(), equalTo(new String[] {INDEX_NAME}));
|
||||
assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
|
||||
assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
|
||||
assertThat(searchRequest.source().trackTotalHitsUpTo(), is(SearchContext.TRACK_TOTAL_HITS_ACCURATE));
|
||||
}
|
||||
|
||||
private void assertSearchScrollRequests(int expectedCount) {
|
||||
List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor;
|
||||
assertThat(searchScrollRequests.size(), equalTo(expectedCount));
|
||||
for (SearchScrollRequest request : searchScrollRequests) {
|
||||
assertThat(request.scrollId(), equalTo(SCROLL_ID));
|
||||
assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
|
||||
}
|
||||
}
|
||||
|
||||
private class ScrollResponsesMocker {
|
||||
private List<String[]> batches = new ArrayList<>();
|
||||
private long totalHits = 0;
|
||||
private List<SearchResponse> responses = new ArrayList<>();
|
||||
|
||||
ScrollResponsesMocker addBatch(String... hits) {
|
||||
totalHits += hits.length;
|
||||
batches.add(hits);
|
||||
return this;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
void finishMock() {
|
||||
if (batches.isEmpty()) {
|
||||
givenInitialResponse();
|
||||
return;
|
||||
}
|
||||
givenInitialResponse(batches.get(0));
|
||||
for (int i = 1; i < batches.size(); ++i) {
|
||||
givenNextResponse(batches.get(i));
|
||||
}
|
||||
if (responses.size() > 0) {
|
||||
SearchResponse first = responses.get(0);
|
||||
if (responses.size() > 1) {
|
||||
List<SearchResponse> rest = new ArrayList<>(responses);
|
||||
Iterator<SearchResponse> responseIterator = rest.iterator();
|
||||
doAnswer((answer) -> {
|
||||
SearchScrollRequest request = (SearchScrollRequest)answer.getArguments()[0];
|
||||
ActionListener<SearchResponse> rsp = (ActionListener<SearchResponse>)answer.getArguments()[1];
|
||||
searchScrollRequestCaptor.add(request);
|
||||
rsp.onResponse(responseIterator.next());
|
||||
return null;
|
||||
}).when(client).searchScroll(any(SearchScrollRequest.class), any(ActionListener.class));
|
||||
} else {
|
||||
doAnswer((answer) -> {
|
||||
SearchScrollRequest request = (SearchScrollRequest)answer.getArguments()[0];
|
||||
ActionListener<SearchResponse> rsp = (ActionListener<SearchResponse>)answer.getArguments()[1];
|
||||
searchScrollRequestCaptor.add(request);
|
||||
rsp.onResponse(first);
|
||||
return null;
|
||||
}).when(client).searchScroll(any(SearchScrollRequest.class), any(ActionListener.class));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void givenInitialResponse(String... hits) {
|
||||
SearchResponse searchResponse = createSearchResponseWithHits(hits);
|
||||
doAnswer((answer) -> {
|
||||
SearchRequest request = (SearchRequest)answer.getArguments()[0];
|
||||
searchRequestCaptor.add(request);
|
||||
ActionListener<SearchResponse> rsp = (ActionListener<SearchResponse>)answer.getArguments()[1];
|
||||
rsp.onResponse(searchResponse);
|
||||
return null;
|
||||
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));
|
||||
}
|
||||
|
||||
private void givenNextResponse(String... hits) {
|
||||
responses.add(createSearchResponseWithHits(hits));
|
||||
}
|
||||
|
||||
private SearchResponse createSearchResponseWithHits(String... hits) {
|
||||
SearchHits searchHits = createHits(hits);
|
||||
SearchResponse searchResponse = mock(SearchResponse.class);
|
||||
when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
|
||||
when(searchResponse.getHits()).thenReturn(searchHits);
|
||||
return searchResponse;
|
||||
}
|
||||
|
||||
private SearchHits createHits(String... values) {
|
||||
List<SearchHit> hits = new ArrayList<>();
|
||||
for (String value : values) {
|
||||
hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
|
||||
}
|
||||
return new SearchHits(hits.toArray(new SearchHit[hits.size()]), new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), 1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
private static class TestIterator extends BatchedDataIterator<String, Deque<String>> {
|
||||
TestIterator(Client client, String jobId) {
|
||||
super(client, jobId);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected QueryBuilder getQuery() {
|
||||
return QueryBuilders.matchAllQuery();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String map(SearchHit hit) {
|
||||
return hit.getSourceAsString();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deque<String> getCollection() {
|
||||
return new ArrayDeque<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SortOrder sortOrder() {
|
||||
return SortOrder.DESC;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String sortField() {
|
||||
return "foo";
|
||||
}
|
||||
}
|
||||
public class SearchHitBuilder {
|
||||
|
||||
private final SearchHit hit;
|
||||
private final Map<String, DocumentField> fields;
|
||||
|
||||
public SearchHitBuilder(int docId) {
|
||||
hit = new SearchHit(docId);
|
||||
fields = new HashMap<>();
|
||||
}
|
||||
|
||||
public SearchHitBuilder setSource(String sourceJson) {
|
||||
hit.sourceRef(new BytesArray(sourceJson));
|
||||
return this;
|
||||
}
|
||||
|
||||
public SearchHit build() {
|
||||
if (!fields.isEmpty()) {
|
||||
hit.fields(fields);
|
||||
}
|
||||
return hit;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -118,3 +118,64 @@ teardown:
|
|||
transform_id: "airline-transform-start-stop"
|
||||
- match: { stopped: true }
|
||||
|
||||
---
|
||||
"Test start/stop only starts/stops specified transform":
|
||||
- do:
|
||||
data_frame.put_data_frame_transform:
|
||||
transform_id: "airline-transform-start-later"
|
||||
body: >
|
||||
{
|
||||
"source": { "index": "airline-data" },
|
||||
"dest": { "index": "airline-data-start-later" },
|
||||
"pivot": {
|
||||
"group_by": { "airline": {"terms": {"field": "airline"}}},
|
||||
"aggs": {"avg_response": {"avg": {"field": "responsetime"}}}
|
||||
}
|
||||
}
|
||||
- do:
|
||||
data_frame.start_data_frame_transform:
|
||||
transform_id: "airline-transform-start-stop"
|
||||
- match: { started: true }
|
||||
|
||||
- do:
|
||||
data_frame.get_data_frame_transform_stats:
|
||||
transform_id: "airline-transform-start-stop"
|
||||
- match: { count: 1 }
|
||||
- match: { transforms.0.id: "airline-transform-start-stop" }
|
||||
- match: { transforms.0.state.indexer_state: "started" }
|
||||
- match: { transforms.0.state.task_state: "started" }
|
||||
|
||||
- do:
|
||||
data_frame.get_data_frame_transform_stats:
|
||||
transform_id: "airline-transform-start-later"
|
||||
- match: { count: 1 }
|
||||
- match: { transforms.0.id: "airline-transform-start-later" }
|
||||
- match: { transforms.0.state.indexer_state: "stopped" }
|
||||
- match: { transforms.0.state.task_state: "stopped" }
|
||||
|
||||
- do:
|
||||
data_frame.start_data_frame_transform:
|
||||
transform_id: "airline-transform-start-later"
|
||||
- match: { started: true }
|
||||
|
||||
- do:
|
||||
data_frame.stop_data_frame_transform:
|
||||
transform_id: "airline-transform-start-stop"
|
||||
- match: { stopped: true }
|
||||
|
||||
- do:
|
||||
data_frame.get_data_frame_transform_stats:
|
||||
transform_id: "airline-transform-start-later"
|
||||
- match: { count: 1 }
|
||||
- match: { transforms.0.id: "airline-transform-start-later" }
|
||||
- match: { transforms.0.state.indexer_state: "started" }
|
||||
- match: { transforms.0.state.task_state: "started" }
|
||||
|
||||
- do:
|
||||
data_frame.stop_data_frame_transform:
|
||||
transform_id: "airline-transform-start-later"
|
||||
- match: { stopped: true }
|
||||
|
||||
- do:
|
||||
data_frame.delete_data_frame_transform:
|
||||
transform_id: "airline-transform-start-later"
|
||||
|
|
Loading…
Reference in New Issue