[ML][Data Frame] Fixes failure state tests and failure setting handling (#44645) (#44698)

* [ML][Data Frame] fixing flaky test

* adjusting frequency

* fixing tests

* addressing PR comments
Benjamin Trent 2019-07-23 08:33:12 -05:00 committed by GitHub
parent 16c8e18013
commit 6f53865fde
8 changed files with 76 additions and 87 deletions

DataFramePivotRestIT.java

@@ -6,12 +6,7 @@
 package org.elasticsearch.xpack.dataframe.integration;
 
-import org.apache.http.entity.ContentType;
-import org.apache.http.entity.StringEntity;
 import org.elasticsearch.client.Request;
-import org.elasticsearch.client.dataframe.transforms.DataFrameTransformTaskState;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.junit.Before;

@@ -22,9 +17,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.TimeUnit;
 
-import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
 import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;

@@ -730,45 +723,6 @@ public class DataFramePivotRestIT extends DataFrameRestTestCase {
         assertEquals(4.47169811, actual.doubleValue(), 0.000001);
     }
 
-    @AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/pull/44583")
-    public void testBulkIndexFailuresCauseTaskToFail() throws Exception {
-        String transformId = "bulk-failure-pivot";
-        String dataFrameIndex = "pivot-failure-index";
-        createPivotReviewsTransform(transformId, dataFrameIndex, null, null, null);
-        try (XContentBuilder builder = jsonBuilder()) {
-            builder.startObject();
-            {
-                builder.startObject("mappings")
-                    .startObject("properties")
-                    .startObject("reviewer")
-                    // This type should cause mapping coercion type conflict on bulk index
-                    .field("type", "long")
-                    .endObject()
-                    .endObject()
-                    .endObject();
-            }
-            builder.endObject();
-            final StringEntity entity = new StringEntity(Strings.toString(builder), ContentType.APPLICATION_JSON);
-            Request req = new Request("PUT", dataFrameIndex);
-            req.setEntity(entity);
-            client().performRequest(req);
-        }
-        startDataframeTransform(transformId, false, null);
-        assertBusy(() -> assertEquals(DataFrameTransformTaskState.FAILED.value(), getDataFrameTaskState(transformId)),
-            120,
-            TimeUnit.SECONDS);
-        Map<?, ?> state = getDataFrameState(transformId);
-        assertThat((String) XContentMapValues.extractValue("state.reason", state),
-            containsString("task encountered more than 10 failures; latest failure: Bulk index experienced failures."));
-
-        // Force stop the transform as bulk indexing caused it to go into a failed state
-        stopDataFrameTransform(transformId, true);
-        deleteIndex(dataFrameIndex);
-    }
-
     private void assertOnePivotValue(String query, double expected) throws IOException {
         Map<String, Object> searchResult = getAsMap(query);

DataFrameRestTestCase.java

@@ -51,11 +51,9 @@ public abstract class DataFrameRestTestCase extends ESRestTestCase {
         return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
     }
 
-    protected void createReviewsIndex(String indexName) throws IOException {
+    protected void createReviewsIndex(String indexName, int numDocs) throws IOException {
         int[] distributionTable = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 3, 3, 2, 1, 1, 1};
-        final int numDocs = 1000;
 
         // create mapping
         try (XContentBuilder builder = jsonBuilder()) {
             builder.startObject();

@@ -146,6 +144,10 @@ public abstract class DataFrameRestTestCase extends ESRestTestCase {
         createReviewsIndex(REVIEWS_INDEX_NAME);
     }
 
+    protected void createReviewsIndex(String indexName) throws IOException {
+        createReviewsIndex(indexName, 1000);
+    }
+
     protected void createPivotReviewsTransform(String transformId, String dataFrameIndex, String query) throws IOException {
         createPivotReviewsTransform(transformId, dataFrameIndex, query, null);
     }

DataFrameTaskFailedStateIT.java

@@ -8,7 +8,6 @@ package org.elasticsearch.xpack.dataframe.integration;
 import org.apache.http.entity.ContentType;
 import org.apache.http.entity.StringEntity;
-import org.apache.lucene.util.LuceneTestCase;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.common.Strings;

@@ -17,6 +16,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformTaskState;
 import org.junit.After;
+import org.junit.Before;
 
 import java.io.IOException;
 import java.util.List;

@@ -27,13 +27,21 @@ import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.oneOf;
 
-@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/pull/44583")
 public class DataFrameTaskFailedStateIT extends DataFrameRestTestCase {
 
     private static final String TRANSFORM_ID = "failure_pivot_1";
 
+    @Before
+    public void setClusterSettings() throws IOException {
+        // Make sure we never retry on failure to speed up the test
+        Request addFailureRetrySetting = new Request("PUT", "/_cluster/settings");
+        addFailureRetrySetting.setJsonEntity(
+            "{\"persistent\": {\"xpack.data_frame.num_transform_failure_retries\": \"" + 0 + "\"}}");
+        client().performRequest(addFailureRetrySetting);
+    }
+
     @After
     public void cleanUpPotentiallyFailedTransform() throws Exception {
         // If the tests failed in the middle, we should force stop it. This prevents other transform tests from failing due

@@ -43,14 +51,14 @@ public class DataFrameTaskFailedStateIT extends DataFrameRestTestCase {
     }
 
     public void testForceStopFailedTransform() throws Exception {
-        createReviewsIndex();
+        createReviewsIndex(REVIEWS_INDEX_NAME, 10);
         String dataFrameIndex = "failure_pivot_reviews";
         createDestinationIndexWithBadMapping(dataFrameIndex);
         createContinuousPivotReviewsTransform(TRANSFORM_ID, dataFrameIndex, null);
         startDataframeTransform(TRANSFORM_ID, false);
         awaitState(TRANSFORM_ID, DataFrameTransformTaskState.FAILED);
         Map<?, ?> fullState = getDataFrameState(TRANSFORM_ID);
-        final String failureReason = "task encountered more than 10 failures; latest failure: " +
+        final String failureReason = "task encountered more than 0 failures; latest failure: " +
             "Bulk index experienced failures. See the logs of the node running the transform for details.";
         // Verify we have failed for the expected reason
         assertThat(XContentMapValues.extractValue("state.reason", fullState),

@@ -69,21 +77,20 @@ public class DataFrameTaskFailedStateIT extends DataFrameRestTestCase {
         awaitState(TRANSFORM_ID, DataFrameTransformTaskState.STOPPED);
         fullState = getDataFrameState(TRANSFORM_ID);
-        // Verify we have failed for the expected reason
         assertThat(XContentMapValues.extractValue("state.reason", fullState),
             is(nullValue()));
     }
 
     public void testForceStartFailedTransform() throws Exception {
-        createReviewsIndex();
+        createReviewsIndex(REVIEWS_INDEX_NAME, 10);
         String dataFrameIndex = "failure_pivot_reviews";
         createDestinationIndexWithBadMapping(dataFrameIndex);
         createContinuousPivotReviewsTransform(TRANSFORM_ID, dataFrameIndex, null);
         startDataframeTransform(TRANSFORM_ID, false);
         awaitState(TRANSFORM_ID, DataFrameTransformTaskState.FAILED);
         Map<?, ?> fullState = getDataFrameState(TRANSFORM_ID);
-        final String failureReason = "task encountered more than 10 failures; latest failure: " +
+        final String failureReason = "task encountered more than 0 failures; latest failure: " +
             "Bulk index experienced failures. See the logs of the node running the transform for details.";
         // Verify we have failed for the expected reason
         assertThat(XContentMapValues.extractValue("state.reason", fullState),

@@ -101,23 +108,15 @@ public class DataFrameTaskFailedStateIT extends DataFrameRestTestCase {
         deleteIndex(dataFrameIndex);
 
         // Force start the data frame to indicate failure correction
         startDataframeTransform(TRANSFORM_ID, true);
 
-        // Wait for data to be indexed appropriately and refresh for search
-        waitForDataFrameCheckpoint(TRANSFORM_ID);
-        refreshIndex(dataFrameIndex);
-
         // Verify that we have started and that our reason is cleared
         fullState = getDataFrameState(TRANSFORM_ID);
         assertThat(XContentMapValues.extractValue("state.reason", fullState), is(nullValue()));
         assertThat(XContentMapValues.extractValue("state.task_state", fullState), equalTo("started"));
-        assertThat(XContentMapValues.extractValue("state.indexer_state", fullState), equalTo("started"));
-        assertThat((int)XContentMapValues.extractValue("stats.index_failures", fullState), greaterThan(0));
-
-        // get and check some users to verify we restarted
-        assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_0", 3.776978417);
-        assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_5", 3.72);
-        assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_11", 3.846153846);
-        assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_20", 3.769230769);
-        assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_26", 3.918918918);
+        assertThat(XContentMapValues.extractValue("state.indexer_state", fullState), is(oneOf("started", "indexing")));
+        assertThat(XContentMapValues.extractValue("stats.index_failures", fullState), equalTo(1));
+
+        stopDataFrameTransform(TRANSFORM_ID, true);
     }
 
     private void awaitState(String transformId, DataFrameTransformTaskState state) throws Exception {
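Both tests recover the task from FAILED via the force flag: force stop in the first, force start in the second. As a rough sketch of what a force-stop helper like stopDataFrameTransform(TRANSFORM_ID, true) presumably issues against the 7.x data frame API (endpoint and parameter names assumed from that release's API, not shown in this diff):

    // Hypothetical sketch: force-stopping a failed transform via the low-level REST client.
    Request stopRequest = new Request("POST", "/_data_frame/transforms/failure_pivot_1/_stop");
    stopRequest.addParameter("force", "true");                // a plain stop is rejected while the task is FAILED
    stopRequest.addParameter("wait_for_completion", "true");
    client().performRequest(stopRequest);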

DataFrame.java

@@ -219,8 +219,15 @@ public class DataFrame extends Plugin implements ActionPlugin, PersistentTaskPlu
         assert dataFrameAuditor.get() != null;
         assert dataFrameTransformsCheckpointService.get() != null;
 
-        return Collections.singletonList(new DataFrameTransformPersistentTasksExecutor(client, dataFrameTransformsConfigManager.get(),
-            dataFrameTransformsCheckpointService.get(), schedulerEngine.get(), dataFrameAuditor.get(), threadPool));
+        return Collections.singletonList(
+            new DataFrameTransformPersistentTasksExecutor(client,
+                dataFrameTransformsConfigManager.get(),
+                dataFrameTransformsCheckpointService.get(),
+                schedulerEngine.get(),
+                dataFrameAuditor.get(),
+                threadPool,
+                clusterService,
+                settingsModule.getSettings()));
     }
 
     public List<Setting<?>> getSettings() {
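For the update consumer registered in the executor to work at all, the setting must be listed in the plugin's getSettings(); ClusterSettings.addSettingsUpdateConsumer rejects settings that were never registered. A minimal sketch of the presumed registration (the real DataFrame.getSettings() almost certainly returns more settings than this one):

    // Assumed minimal registration; the actual method likely lists additional settings.
    public List<Setting<?>> getSettings() {
        return Collections.singletonList(DataFrameTransformTask.NUM_FAILURE_RETRIES_SETTING);
    }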

TransportStartDataFrameTransformTaskAction.java

@@ -32,7 +32,6 @@ public class TransportStartDataFrameTransformTaskAction extends
     TransportTasksAction<DataFrameTransformTask, StartDataFrameTransformTaskAction.Request,
     StartDataFrameTransformTaskAction.Response, StartDataFrameTransformTaskAction.Response> {
 
-    private volatile int numFailureRetries;
     private final XPackLicenseState licenseState;
 
     @Inject

@@ -42,8 +41,6 @@ public class TransportStartDataFrameTransformTaskAction extends
             StartDataFrameTransformTaskAction.Request::new, StartDataFrameTransformTaskAction.Response::new,
             StartDataFrameTransformTaskAction.Response::new, ThreadPool.Names.SAME);
         this.licenseState = licenseState;
-        clusterService.getClusterSettings()
-            .addSettingsUpdateConsumer(DataFrameTransformTask.NUM_FAILURE_RETRIES_SETTING, this::setNumFailureRetries);
     }
 
     @Override

@@ -62,7 +59,7 @@ public class TransportStartDataFrameTransformTaskAction extends
     protected void taskOperation(StartDataFrameTransformTaskAction.Request request, DataFrameTransformTask transformTask,
                                  ActionListener<StartDataFrameTransformTaskAction.Response> listener) {
         if (transformTask.getTransformId().equals(request.getId())) {
-            transformTask.setNumFailureRetries(numFailureRetries).start(null, listener);
+            transformTask.start(null, listener);
         } else {
             listener.onFailure(new RuntimeException("ID of data frame transform task [" + transformTask.getTransformId()
                 + "] does not match request's ID [" + request.getId() + "]"));

@@ -93,8 +90,4 @@ public class TransportStartDataFrameTransformTaskAction extends
         boolean allStarted = tasks.stream().allMatch(StartDataFrameTransformTaskAction.Response::isStarted);
         return new StartDataFrameTransformTaskAction.Response(allStarted);
     }
-
-    void setNumFailureRetries(int numFailureRetries) {
-        this.numFailureRetries = numFailureRetries;
-    }
 }

DataFrameTransformPersistentTasksExecutor.java

@@ -18,7 +18,9 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.persistent.AllocatedPersistentTask;
 import org.elasticsearch.persistent.PersistentTaskState;
 import org.elasticsearch.persistent.PersistentTasksCustomMetaData;

@@ -61,13 +63,16 @@ public class DataFrameTransformPersistentTasksExecutor extends PersistentTasksEx
     private final SchedulerEngine schedulerEngine;
     private final ThreadPool threadPool;
     private final DataFrameAuditor auditor;
+    private volatile int numFailureRetries;
 
     public DataFrameTransformPersistentTasksExecutor(Client client,
                                                      DataFrameTransformsConfigManager transformsConfigManager,
                                                      DataFrameTransformsCheckpointService dataFrameTransformsCheckpointService,
                                                      SchedulerEngine schedulerEngine,
                                                      DataFrameAuditor auditor,
-                                                     ThreadPool threadPool) {
+                                                     ThreadPool threadPool,
+                                                     ClusterService clusterService,
+                                                     Settings settings) {
         super(DataFrameField.TASK_NAME, DataFrame.TASK_THREAD_POOL_NAME);
         this.client = client;
         this.transformsConfigManager = transformsConfigManager;

@@ -75,6 +80,9 @@ public class DataFrameTransformPersistentTasksExecutor extends PersistentTasksEx
         this.schedulerEngine = schedulerEngine;
         this.auditor = auditor;
         this.threadPool = threadPool;
+        this.numFailureRetries = DataFrameTransformTask.NUM_FAILURE_RETRIES_SETTING.get(settings);
+        clusterService.getClusterSettings()
+            .addSettingsUpdateConsumer(DataFrameTransformTask.NUM_FAILURE_RETRIES_SETTING, this::setNumFailureRetries);
     }
 
     @Override

@@ -286,7 +294,11 @@ public class DataFrameTransformPersistentTasksExecutor extends PersistentTasksEx
                                    Long previousCheckpoint,
                                    ActionListener<StartDataFrameTransformTaskAction.Response> listener) {
         buildTask.initializeIndexer(indexerBuilder);
-        buildTask.start(previousCheckpoint, listener);
+        buildTask.setNumFailureRetries(numFailureRetries).start(previousCheckpoint, listener);
+    }
+
+    private void setNumFailureRetries(int numFailureRetries) {
+        this.numFailureRetries = numFailureRetries;
     }
 
     @Override
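The NUM_FAILURE_RETRIES_SETTING read and subscribed to above follows the standard Elasticsearch dynamic-setting pattern: read the node-level value from Settings once at construction, then keep a volatile field current via a settings-update consumer. A sketch of what such a declaration looks like (the real definition lives in DataFrameTransformTask; the default and bounds here are assumptions, only the key is visible in this diff):

    // Hypothetical declaration; the default (10) and bounds (0..100) are assumptions.
    public static final Setting<Integer> NUM_FAILURE_RETRIES_SETTING = Setting.intSetting(
        "xpack.data_frame.num_transform_failure_retries", 10, 0, 100,
        Setting.Property.NodeScope,   // readable from node settings at startup
        Setting.Property.Dynamic);    // updatable at runtime via the cluster settings API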

DataFrameTransformTask.java

@@ -329,19 +329,33 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
     }
 
     synchronized void markAsFailed(String reason, ActionListener<Void> listener) {
-        taskState.set(DataFrameTransformTaskState.FAILED);
-        stateReason.set(reason);
         auditor.error(transform.getId(), reason);
         // We should not keep retrying. Either the task will be stopped, or started
         // If it is started again, it is registered again.
         deregisterSchedulerJob();
+        DataFrameTransformState newState = new DataFrameTransformState(
+            DataFrameTransformTaskState.FAILED,
+            initialIndexerState,
+            initialPosition,
+            currentCheckpoint.get(),
+            reason,
+            getIndexer() == null ? null : getIndexer().getProgress());
         // Even though the indexer information is persisted to an index, we still need DataFrameTransformTaskState in the clusterstate
         // This keeps track of STARTED, FAILED, STOPPED
         // This is because a FAILED state can occur because we cannot read the config from the internal index, which would imply that
         // we could not read the previous state information from said index.
-        persistStateToClusterState(getState(), ActionListener.wrap(
-            r -> listener.onResponse(null),
-            listener::onFailure
+        persistStateToClusterState(newState, ActionListener.wrap(
+            r -> {
+                taskState.set(DataFrameTransformTaskState.FAILED);
+                stateReason.set(reason);
+                listener.onResponse(null);
+            },
+            e -> {
+                logger.error("Failed to set task state as failed to cluster state", e);
+                taskState.set(DataFrameTransformTaskState.FAILED);
+                stateReason.set(reason);
+                listener.onFailure(e);
+            }
         ));
     }
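The reordering in markAsFailed is the behavioral core of this change: the old code flipped the in-memory taskState/stateReason before attempting the cluster-state update, while the new code builds the FAILED state up front, persists it, and only mutates the in-memory fields inside the listener callbacks, logging the error if persistence fails. A self-contained sketch of that persist-then-mutate pattern (all names here are hypothetical stand-ins, not the real task's fields):

    import org.elasticsearch.action.ActionListener;
    import java.util.concurrent.atomic.AtomicReference;

    class PersistThenMutate {
        private final AtomicReference<String> localState = new AtomicReference<>("started");

        void markFailed(String reason, ActionListener<Void> listener) {
            String pending = "failed: " + reason;       // build the state to persist first
            persist(pending, ActionListener.wrap(
                r -> { localState.set(pending); listener.onResponse(null); },  // mutate only after the ack
                e -> { localState.set(pending); listener.onFailure(e); }       // still fail locally, but surface the error
            ));
        }

        private void persist(String state, ActionListener<Void> l) {
            l.onResponse(null);                         // stand-in for the cluster-state update
        }
    }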

DataFrameTransformPersistentTasksExecutorTests.java

@@ -21,6 +21,8 @@ import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;

@@ -42,6 +44,7 @@ import java.util.List;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class DataFrameTransformPersistentTasksExecutorTests extends ESTestCase {

@@ -99,12 +102,17 @@ public class DataFrameTransformPersistentTasksExecutorTests extends ESTestCase {
         DataFrameTransformsConfigManager transformsConfigManager = new DataFrameTransformsConfigManager(client, xContentRegistry());
         DataFrameTransformsCheckpointService dataFrameTransformsCheckpointService = new DataFrameTransformsCheckpointService(client,
             transformsConfigManager);
+        ClusterSettings cSettings = new ClusterSettings(Settings.EMPTY,
+            Collections.singleton(DataFrameTransformTask.NUM_FAILURE_RETRIES_SETTING));
+        ClusterService clusterService = mock(ClusterService.class);
+        when(clusterService.getClusterSettings()).thenReturn(cSettings);
         DataFrameTransformPersistentTasksExecutor executor = new DataFrameTransformPersistentTasksExecutor(client,
             transformsConfigManager,
             dataFrameTransformsCheckpointService, mock(SchedulerEngine.class),
             new DataFrameAuditor(client, ""),
-            mock(ThreadPool.class));
+            mock(ThreadPool.class),
+            clusterService,
+            Settings.EMPTY);
 
         assertThat(executor.getAssignment(new DataFrameTransform("new-task-id", Version.CURRENT, null), cs).getExecutorNode(),
             equalTo("current-data-node-with-1-tasks"));