[ML] refactoring start task a bit, removing unused code (#40798) (#40845)

Benjamin Trent 2019-04-05 09:01:01 -05:00 committed by GitHub
parent 922a70ce32
commit 665f0d81aa
7 changed files with 67 additions and 703 deletions

View File (StartDataFrameTransformTaskAction.java)

@@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -62,6 +63,11 @@ public class StartDataFrameTransformTaskAction extends Action<StartDataFrameTran
out.writeString(id);
}
@Override
public boolean match(Task task) {
return task.getDescription().equals(DataFrameField.PERSISTENT_TASK_DESCRIPTION_PREFIX + id);
}
@Override
public ActionRequestValidationException validate() {
return null;

View File (TransportStartDataFrameTransformTaskAction.java)

@@ -24,7 +24,6 @@ import org.elasticsearch.xpack.core.dataframe.action.StartDataFrameTransformTask
import org.elasticsearch.xpack.dataframe.transforms.DataFrameTransformTask;
import java.util.List;
import java.util.function.Consumer;
/**
* Internal-only transport class to change an allocated persistent task's state to started
@@ -44,27 +43,6 @@ public class TransportStartDataFrameTransformTaskAction extends
this.licenseState = licenseState;
}
@Override
protected void processTasks(StartDataFrameTransformTaskAction.Request request, Consumer<DataFrameTransformTask> operation) {
DataFrameTransformTask matchingTask = null;
// todo: re-factor, see rollup TransportTaskHelper
for (Task task : taskManager.getTasks().values()) {
if (task instanceof DataFrameTransformTask
&& ((DataFrameTransformTask) task).getTransformId().equals(request.getId())) {
if (matchingTask != null) {
throw new IllegalArgumentException("Found more than one matching task for data frame transform [" + request.getId()
+ "] when " + "there should only be one.");
}
matchingTask = (DataFrameTransformTask) task;
}
}
if (matchingTask != null) {
operation.accept(matchingTask);
}
}
@Override
protected void doExecute(Task task, StartDataFrameTransformTaskAction.Request request,
ActionListener<StartDataFrameTransformTaskAction.Response> listener) {
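The match override added in the first file is what makes the processTasks deletion above safe: instead of each transport action walking the task manager itself, the generic tasks machinery can ask the request which tasks it applies to. A minimal sketch of the equivalent filtering, reusing the taskManager field and operation consumer this class already had — illustrative only, the actual iteration lives in the transport base class:

import java.util.function.Consumer;
import org.elasticsearch.tasks.Task;

// Each local task's description is compared against
// PERSISTENT_TASK_DESCRIPTION_PREFIX + id by Request#match.
protected void processTasks(StartDataFrameTransformTaskAction.Request request,
                            Consumer<DataFrameTransformTask> operation) {
    for (Task task : taskManager.getTasks().values()) {
        if (task instanceof DataFrameTransformTask && request.match(task)) {
            operation.accept((DataFrameTransformTask) task);
        }
    }
}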

View File (DataFramePersistentTaskUtils.java, deleted)

@@ -1,41 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.dataframe.persistence;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
public final class DataFramePersistentTaskUtils {
private DataFramePersistentTaskUtils() {
}
/**
* Checks whether the cluster state's persistent tasks contain the data frame transform(s) we
* are interested in.
*/
public static boolean stateHasDataFrameTransforms(String id, ClusterState state) {
boolean hasTransforms = false;
PersistentTasksCustomMetaData pTasksMeta = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
if (pTasksMeta != null) {
// If the request was for _all transforms, we need to look through the list of
// persistent tasks and see if at least one is a data frame task
if (id.equals(MetaData.ALL)) {
hasTransforms = pTasksMeta.tasks().stream()
.anyMatch(persistentTask -> persistentTask.getTaskName().equals(DataFrameField.TASK_NAME));
} else if (pTasksMeta.getTask(id) != null) {
// If we're looking for a single transform, we can just check directly
hasTransforms = true;
}
}
return hasTransforms;
}
}
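For reference, a hypothetical caller of the deleted helper would have looked like the sketch below; the class is removed precisely because no production code used it (anyTransformTaskExists is an assumed name, not from this commit):

import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;

// Ask whether any data frame transform persistent task exists in the
// current cluster state (MetaData.ALL stands for the "_all" wildcard).
static boolean anyTransformTaskExists(ClusterState state) {
    return DataFramePersistentTaskUtils.stateHasDataFrameTransforms(MetaData.ALL, state);
}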

View File (BatchedDataIterator.java, deleted)

@@ -1,186 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.dataframe.util;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import java.util.Collection;
import java.util.Collections;
import java.util.NoSuchElementException;
import java.util.Objects;
/**
* Provides basic tools for scrolling over documents and gathering the data into some Collection.
* @param <T> The object type that is being collected
* @param <E> The collection type that should be used (e.g. Set, Deque, etc.)
*/
public abstract class BatchedDataIterator<T, E extends Collection<T>> {
private static final Logger LOGGER = LogManager.getLogger(BatchedDataIterator.class);
private static final String CONTEXT_ALIVE_DURATION = "5m";
private static final int BATCH_SIZE = 10_000;
private final Client client;
private final String index;
private volatile long count;
private volatile long totalHits;
private volatile String scrollId;
private volatile boolean isScrollInitialised;
protected BatchedDataIterator(Client client, String index) {
this.client = Objects.requireNonNull(client);
this.index = Objects.requireNonNull(index);
this.totalHits = 0;
this.count = 0;
}
/**
* Returns {@code true} if the iteration has more elements.
* (In other words, returns {@code true} if {@link #next} would
* return an element rather than throwing an exception.)
*
* @return {@code true} if the iteration has more elements
*/
public boolean hasNext() {
return !isScrollInitialised || count != totalHits;
}
/**
* The first time next() is called, the search will be performed and the first
* batch will be given to the listener. Any subsequent call will return the following batches.
* <p>
* Note that in some implementations it is possible, when there are no
* results at all, that {@link BatchedDataIterator#hasNext()} will return {@code true} the first time it is called but a call
* to this function then returns an empty Collection to the listener.
*/
public void next(ActionListener<E> listener) {
if (!hasNext()) {
listener.onFailure(new NoSuchElementException());
return; // do not fall through and issue another search after failing the listener
}
if (!isScrollInitialised) {
ActionListener<SearchResponse> wrappedListener = ActionListener.wrap(
searchResponse -> {
isScrollInitialised = true;
totalHits = searchResponse.getHits().getTotalHits().value;
scrollId = searchResponse.getScrollId();
mapHits(searchResponse, listener);
},
listener::onFailure
);
initScroll(wrappedListener);
} else {
ActionListener<SearchResponse> wrappedListener = ActionListener.wrap(
searchResponse -> {
scrollId = searchResponse.getScrollId();
mapHits(searchResponse, listener);
},
listener::onFailure
);
SearchScrollRequest searchScrollRequest = new SearchScrollRequest(scrollId).scroll(CONTEXT_ALIVE_DURATION);
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ClientHelper.DATA_FRAME_ORIGIN,
searchScrollRequest,
wrappedListener,
client::searchScroll);
}
}
private void initScroll(ActionListener<SearchResponse> listener) {
LOGGER.trace("ES API CALL: search index {}", index);
SearchRequest searchRequest = new SearchRequest(index);
searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
searchRequest.scroll(CONTEXT_ALIVE_DURATION);
searchRequest.source(new SearchSourceBuilder()
.fetchSource(getFetchSourceContext())
.size(getBatchSize())
.query(getQuery())
.trackTotalHits(true)
.sort(sortField(), sortOrder()));
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ClientHelper.DATA_FRAME_ORIGIN,
searchRequest,
listener,
client::search);
}
private void mapHits(SearchResponse searchResponse, ActionListener<E> mappingListener) {
E results = getCollection();
SearchHit[] hits = searchResponse.getHits().getHits();
for (SearchHit hit : hits) {
T mapped = map(hit);
if (mapped != null) {
results.add(mapped);
}
}
count += hits.length;
if (!hasNext() && scrollId != null) {
ClearScrollRequest request = client.prepareClearScroll().setScrollIds(Collections.singletonList(scrollId)).request();
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ClientHelper.DATA_FRAME_ORIGIN,
request,
ActionListener.<ClearScrollResponse>wrap(
r -> mappingListener.onResponse(results),
mappingListener::onFailure
),
client::clearScroll);
} else {
mappingListener.onResponse(results);
}
}
/**
* Get the query to use for the search
* @return the search query
*/
protected abstract QueryBuilder getQuery();
/**
* Maps the search hit to the document type
* @param hit the search hit
* @return The mapped document or {@code null} if the mapping failed
*/
protected abstract T map(SearchHit hit);
protected abstract E getCollection();
protected abstract SortOrder sortOrder();
protected abstract String sortField();
/**
* Whether we should fetch the source, and which fields specifically.
*
* Defaults to fetching the full source (all fields).
*/
protected FetchSourceContext getFetchSourceContext() {
return FetchSourceContext.FETCH_SOURCE;
}
protected int getBatchSize() {
return BATCH_SIZE;
}
}
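Callers drove this iterator by alternating hasNext() and next(listener) until the scroll was exhausted. A minimal concrete subclass, modeled on the TestIterator in the removed test file below, might have looked like this (a sketch; the index name and sort field are assumptions):

import java.util.ArrayDeque;
import java.util.Deque;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;

// Collects each scroll page as a Deque of raw JSON source strings.
class SourceIterator extends BatchedDataIterator<String, Deque<String>> {

    SourceIterator(Client client) {
        super(client, "my-index"); // assumed index name
    }

    @Override
    protected QueryBuilder getQuery() {
        return QueryBuilders.matchAllQuery();
    }

    @Override
    protected String map(SearchHit hit) {
        return hit.getSourceAsString();
    }

    @Override
    protected Deque<String> getCollection() {
        return new ArrayDeque<>();
    }

    @Override
    protected SortOrder sortOrder() {
        return SortOrder.ASC;
    }

    @Override
    protected String sortField() {
        return "timestamp"; // assumed sort field
    }
}

Each next(listener) call then delivers one mapped batch asynchronously; once count catches up with totalHits, hasNext() turns false and the scroll context is cleared.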

View File (TypedChainTaskExecutor.java, deleted)

@@ -1,125 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.dataframe.util;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;
/**
* A utility that allows chained (serial) execution of a number of tasks
* in an asynchronous manner.
*/
public class TypedChainTaskExecutor<T> {
public interface ChainTask <T> {
void run(ActionListener<T> listener);
}
private final ExecutorService executorService;
private final LinkedList<ChainTask<T>> tasks = new LinkedList<>();
private final Predicate<Exception> failureShortCircuitPredicate;
private final Predicate<T> continuationPredicate;
private final List<T> collectedResponses;
/**
* Creates a new TypedChainTaskExecutor.
* Each chained task is executed serially, in order, and after each execution the continuationPredicate is tested.
*
* On failures the failureShortCircuitPredicate is tested.
*
* @param executorService The executor service used to run the tasks
* @param continuationPredicate The predicate that decides whether to execute the next task or not.
* {@code true} means continue on to the next task.
* Must be able to handle null values.
* @param failureShortCircuitPredicate The predicate that decides whether to short-circuit execution on a given exception.
* {@code true} means that no more tasks should execute and listener::onFailure should be
* called.
*/
public TypedChainTaskExecutor(ExecutorService executorService,
Predicate<T> continuationPredicate,
Predicate<Exception> failureShortCircuitPredicate) {
this.executorService = Objects.requireNonNull(executorService);
this.continuationPredicate = continuationPredicate;
this.failureShortCircuitPredicate = failureShortCircuitPredicate;
this.collectedResponses = new ArrayList<>();
}
public synchronized void add(ChainTask<T> task) {
tasks.add(task);
}
private synchronized void execute(T previousValue, ActionListener<List<T>> listener) {
collectedResponses.add(previousValue);
if (continuationPredicate.test(previousValue)) {
if (tasks.isEmpty()) {
listener.onResponse(Collections.unmodifiableList(new ArrayList<>(collectedResponses)));
return;
}
ChainTask<T> task = tasks.pop();
executorService.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
if (failureShortCircuitPredicate.test(e)) {
listener.onFailure(e);
} else {
execute(null, listener);
}
}
@Override
protected void doRun() {
task.run(ActionListener.wrap(value -> execute(value, listener), this::onFailure));
}
});
} else {
listener.onResponse(Collections.unmodifiableList(new ArrayList<>(collectedResponses)));
}
}
/**
* Execute all the chained tasks serially and notify the listener when completed.
*
* @param listener The ActionListener to notify when all executions have been completed,
* or when no further tasks should be executed.
* The resulting list COULD contain null values depending on whether execution
* continues after exceptions or not.
*/
public synchronized void execute(ActionListener<List<T>> listener) {
if (tasks.isEmpty()) {
listener.onResponse(Collections.emptyList());
return;
}
collectedResponses.clear();
ChainTask<T> task = tasks.pop();
executorService.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
if (failureShortCircuitPredicate.test(e)) {
listener.onFailure(e);
} else {
execute(null, listener);
}
}
@Override
protected void doRun() {
task.run(ActionListener.wrap(value -> execute(value, listener), this::onFailure));
}
});
}
public synchronized List<T> getCollectedResponses() {
return Collections.unmodifiableList(new ArrayList<>(collectedResponses));
}
}
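Usage of the deleted executor followed an add-then-execute pattern. A short sketch under assumed predicates — continue while each step reports true, abort on any exception:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.elasticsearch.action.ActionListener;

// Chain two asynchronous steps serially; results are collected in order.
ExecutorService executor = Executors.newSingleThreadExecutor();
TypedChainTaskExecutor<Boolean> chain = new TypedChainTaskExecutor<>(
    executor,
    result -> Boolean.TRUE.equals(result), // null-safe continuation check
    e -> true);                            // any failure short-circuits
chain.add(listener -> listener.onResponse(true)); // step 1
chain.add(listener -> listener.onResponse(true)); // step 2
chain.execute(ActionListener.wrap(
    results -> System.out.println("finished " + results.size() + " steps"),
    e -> System.err.println("chain aborted: " + e)));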

View File (BatchedDataIteratorTests.java, deleted)

@@ -1,329 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.dataframe.util;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.ClearScrollRequestBuilder;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.Before;
import org.mockito.Mockito;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ExecutionException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class BatchedDataIteratorTests extends ESTestCase {
private static final String INDEX_NAME = "some_index_name";
private static final String SCROLL_ID = "someScrollId";
private Client client;
private boolean wasScrollCleared;
private TestIterator testIterator;
private List<SearchRequest> searchRequestCaptor = new ArrayList<>();
private List<SearchScrollRequest> searchScrollRequestCaptor = new ArrayList<>();
@Before
public void setUpMocks() {
ThreadPool pool = mock(ThreadPool.class);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(pool.getThreadContext()).thenReturn(threadContext);
client = Mockito.mock(Client.class);
when(client.threadPool()).thenReturn(pool);
wasScrollCleared = false;
testIterator = new TestIterator(client, INDEX_NAME);
givenClearScrollRequest();
searchRequestCaptor.clear();
searchScrollRequestCaptor.clear();
}
public void testQueryReturnsNoResults() throws Exception {
new ScrollResponsesMocker().finishMock();
assertTrue(testIterator.hasNext());
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
testIterator.next(future);
assertTrue(future.get().isEmpty());
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(0);
}
public void testCallingNextWhenHasNextIsFalseThrows() throws Exception {
PlainActionFuture<Deque<String>> firstFuture = new PlainActionFuture<>();
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
testIterator.next(firstFuture);
firstFuture.get();
assertFalse(testIterator.hasNext());
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
ExecutionException executionException = ESTestCase.expectThrows(ExecutionException.class, () -> {
testIterator.next(future);
future.get();
});
assertNotNull(executionException.getCause());
assertTrue(executionException.getCause() instanceof NoSuchElementException);
}
public void testQueryReturnsSingleBatch() throws Exception {
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
assertTrue(testIterator.hasNext());
testIterator.next(future);
Deque<String> batch = future.get();
assertEquals(3, batch.size());
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(0);
}
public void testQueryReturnsThreeBatches() throws Exception {
PlainActionFuture<Deque<String>> future = new PlainActionFuture<>();
new ScrollResponsesMocker()
.addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
.addBatch(createJsonDoc("d"), createJsonDoc("e"))
.addBatch(createJsonDoc("f"))
.finishMock();
assertTrue(testIterator.hasNext());
testIterator.next(future);
Deque<String> batch = future.get();
assertEquals(3, batch.size());
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))));
future = new PlainActionFuture<>();
testIterator.next(future);
batch = future.get();
assertEquals(2, batch.size());
assertTrue(batch.containsAll(Arrays.asList(createJsonDoc("d"), createJsonDoc("e"))));
future = new PlainActionFuture<>();
testIterator.next(future);
batch = future.get();
assertEquals(1, batch.size());
assertTrue(batch.contains(createJsonDoc("f")));
assertFalse(testIterator.hasNext());
assertTrue(wasScrollCleared);
assertSearchRequest();
assertSearchScrollRequests(2);
}
private String createJsonDoc(String value) {
return "{\"foo\":\"" + value + "\"}";
}
@SuppressWarnings("unchecked")
private void givenClearScrollRequest() {
ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class);
when(client.prepareClearScroll()).thenReturn(requestBuilder);
when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder);
ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
clearScrollRequest.addScrollId(SCROLL_ID);
when(requestBuilder.request()).thenReturn(clearScrollRequest);
doAnswer((answer) -> {
wasScrollCleared = true;
ActionListener<ClearScrollResponse> scrollListener =
(ActionListener<ClearScrollResponse>) answer.getArguments()[1];
scrollListener.onResponse(new ClearScrollResponse(true,0));
return null;
}).when(client).clearScroll(any(ClearScrollRequest.class), any(ActionListener.class));
}
private void assertSearchRequest() {
List<SearchRequest> searchRequests = searchRequestCaptor;
assertThat(searchRequests.size(), equalTo(1));
SearchRequest searchRequest = searchRequests.get(0);
assertThat(searchRequest.indices(), equalTo(new String[] {INDEX_NAME}));
assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
assertThat(searchRequest.source().trackTotalHitsUpTo(), is(SearchContext.TRACK_TOTAL_HITS_ACCURATE));
}
private void assertSearchScrollRequests(int expectedCount) {
List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor;
assertThat(searchScrollRequests.size(), equalTo(expectedCount));
for (SearchScrollRequest request : searchScrollRequests) {
assertThat(request.scrollId(), equalTo(SCROLL_ID));
assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
}
}
private class ScrollResponsesMocker {
private List<String[]> batches = new ArrayList<>();
private long totalHits = 0;
private List<SearchResponse> responses = new ArrayList<>();
ScrollResponsesMocker addBatch(String... hits) {
totalHits += hits.length;
batches.add(hits);
return this;
}
@SuppressWarnings("unchecked")
void finishMock() {
if (batches.isEmpty()) {
givenInitialResponse();
return;
}
givenInitialResponse(batches.get(0));
for (int i = 1; i < batches.size(); ++i) {
givenNextResponse(batches.get(i));
}
if (responses.size() > 0) {
SearchResponse first = responses.get(0);
if (responses.size() > 1) {
List<SearchResponse> rest = new ArrayList<>(responses);
Iterator<SearchResponse> responseIterator = rest.iterator();
doAnswer((answer) -> {
SearchScrollRequest request = (SearchScrollRequest)answer.getArguments()[0];
ActionListener<SearchResponse> rsp = (ActionListener<SearchResponse>)answer.getArguments()[1];
searchScrollRequestCaptor.add(request);
rsp.onResponse(responseIterator.next());
return null;
}).when(client).searchScroll(any(SearchScrollRequest.class), any(ActionListener.class));
} else {
doAnswer((answer) -> {
SearchScrollRequest request = (SearchScrollRequest)answer.getArguments()[0];
ActionListener<SearchResponse> rsp = (ActionListener<SearchResponse>)answer.getArguments()[1];
searchScrollRequestCaptor.add(request);
rsp.onResponse(first);
return null;
}).when(client).searchScroll(any(SearchScrollRequest.class), any(ActionListener.class));
}
}
}
@SuppressWarnings("unchecked")
private void givenInitialResponse(String... hits) {
SearchResponse searchResponse = createSearchResponseWithHits(hits);
doAnswer((answer) -> {
SearchRequest request = (SearchRequest)answer.getArguments()[0];
searchRequestCaptor.add(request);
ActionListener<SearchResponse> rsp = (ActionListener<SearchResponse>)answer.getArguments()[1];
rsp.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));
}
private void givenNextResponse(String... hits) {
responses.add(createSearchResponseWithHits(hits));
}
private SearchResponse createSearchResponseWithHits(String... hits) {
SearchHits searchHits = createHits(hits);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
when(searchResponse.getHits()).thenReturn(searchHits);
return searchResponse;
}
private SearchHits createHits(String... values) {
List<SearchHit> hits = new ArrayList<>();
for (String value : values) {
hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
}
return new SearchHits(hits.toArray(new SearchHit[hits.size()]), new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), 1.0f);
}
}
private static class TestIterator extends BatchedDataIterator<String, Deque<String>> {
TestIterator(Client client, String jobId) {
super(client, jobId);
}
@Override
protected QueryBuilder getQuery() {
return QueryBuilders.matchAllQuery();
}
@Override
protected String map(SearchHit hit) {
return hit.getSourceAsString();
}
@Override
protected Deque<String> getCollection() {
return new ArrayDeque<>();
}
@Override
protected SortOrder sortOrder() {
return SortOrder.DESC;
}
@Override
protected String sortField() {
return "foo";
}
}
public class SearchHitBuilder {
private final SearchHit hit;
private final Map<String, DocumentField> fields;
public SearchHitBuilder(int docId) {
hit = new SearchHit(docId);
fields = new HashMap<>();
}
public SearchHitBuilder setSource(String sourceJson) {
hit.sourceRef(new BytesArray(sourceJson));
return this;
}
public SearchHit build() {
if (!fields.isEmpty()) {
hit.fields(fields);
}
return hit;
}
}
}

View File (data frame transforms start/stop REST test, YAML)

@@ -118,3 +118,64 @@ teardown:
        transform_id: "airline-transform-start-stop"
  - match: { stopped: true }

---
"Test start/stop only starts/stops specified transform":
  - do:
      data_frame.put_data_frame_transform:
        transform_id: "airline-transform-start-later"
        body: >
          {
            "source": { "index": "airline-data" },
            "dest": { "index": "airline-data-start-later" },
            "pivot": {
              "group_by": { "airline": {"terms": {"field": "airline"}}},
              "aggs": {"avg_response": {"avg": {"field": "responsetime"}}}
            }
          }

  - do:
      data_frame.start_data_frame_transform:
        transform_id: "airline-transform-start-stop"
  - match: { started: true }

  - do:
      data_frame.get_data_frame_transform_stats:
        transform_id: "airline-transform-start-stop"
  - match: { count: 1 }
  - match: { transforms.0.id: "airline-transform-start-stop" }
  - match: { transforms.0.state.indexer_state: "started" }
  - match: { transforms.0.state.task_state: "started" }

  - do:
      data_frame.get_data_frame_transform_stats:
        transform_id: "airline-transform-start-later"
  - match: { count: 1 }
  - match: { transforms.0.id: "airline-transform-start-later" }
  - match: { transforms.0.state.indexer_state: "stopped" }
  - match: { transforms.0.state.task_state: "stopped" }

  - do:
      data_frame.start_data_frame_transform:
        transform_id: "airline-transform-start-later"
  - match: { started: true }

  - do:
      data_frame.stop_data_frame_transform:
        transform_id: "airline-transform-start-stop"
  - match: { stopped: true }

  - do:
      data_frame.get_data_frame_transform_stats:
        transform_id: "airline-transform-start-later"
  - match: { count: 1 }
  - match: { transforms.0.id: "airline-transform-start-later" }
  - match: { transforms.0.state.indexer_state: "started" }
  - match: { transforms.0.state.task_state: "started" }

  - do:
      data_frame.stop_data_frame_transform:
        transform_id: "airline-transform-start-later"
  - match: { stopped: true }

  - do:
      data_frame.delete_data_frame_transform:
        transform_id: "airline-transform-start-later"