Persistent Tasks: require correct allocation id for status updates (#923)

To prevent task state updates by stale executors, this commit adds a check for the correct allocation id during the status update operation.
This commit is contained in:
Igor Motov 2017-04-01 18:17:07 -04:00 committed by Martijn van Groningen
parent 6ca044736e
commit 1b0f5b9572
No known key found for this signature in database
GPG Key ID: AB236F4FCF2AF12A
11 changed files with 62 additions and 47 deletions

View File

@ -32,6 +32,7 @@ import java.util.concurrent.atomic.AtomicReference;
*/
public class AllocatedPersistentTask extends CancellableTask {
private long persistentTaskId;
private long allocationId;
private final AtomicReference<State> state;
@Nullable
@ -70,16 +71,17 @@ public class AllocatedPersistentTask extends CancellableTask {
* This doesn't affect the status of this allocated task.
*/
public void updatePersistentStatus(Task.Status status, PersistentTasksService.PersistentTaskOperationListener listener) {
persistentTasksService.updateStatus(persistentTaskId, status, listener);
persistentTasksService.updateStatus(persistentTaskId, allocationId, status, listener);
}
/**
 * Returns the id of the persistent task stored in this object's {@code persistentTaskId} field.
 */
public long getPersistentTaskId() {
return persistentTaskId;
}
void init(PersistentTasksService persistentTasksService, long persistentTaskId) {
void init(PersistentTasksService persistentTasksService, long persistentTaskId, long allocationId) {
this.persistentTasksService = persistentTasksService;
this.persistentTaskId = persistentTaskId;
this.allocationId = allocationId;
}
public Exception getFailure() {

View File

@ -25,6 +25,8 @@ import org.elasticsearch.transport.TransportResponse.Empty;
/**
* This component is responsible for execution of persistent tasks.
*
* It abstracts away the execution of tasks and greatly simplifies testing of PersistentTasksNodeService
*/
public class NodePersistentTasksExecutor {
private final ThreadPool threadPool;

View File

@ -192,18 +192,24 @@ public class PersistentTasksClusterService extends AbstractComponent implements
* Update task status
*
* @param id the id of a persistent task
* @param allocationId the expected allocation id of the persistent task
* @param status new status
* @param listener the listener that will be notified of the outcome of the status update
*/
public void updatePersistentTaskStatus(long id, Task.Status status, ActionListener<Empty> listener) {
public void updatePersistentTaskStatus(long id, long allocationId, Task.Status status, ActionListener<Empty> listener) {
clusterService.submitStateUpdateTask("update task status", new ClusterStateUpdateTask() {
@Override
public ClusterState execute(ClusterState currentState) throws Exception {
PersistentTasksCustomMetaData.Builder tasksInProgress = builder(currentState);
if (tasksInProgress.hasTask(id)) {
if (tasksInProgress.hasTask(id, allocationId)) {
return update(currentState, tasksInProgress.updateTaskStatus(id, status));
} else {
throw new ResourceNotFoundException("the task with id {} doesn't exist", id);
if (tasksInProgress.hasTask(id)) {
logger.warn("trying to update status on task {} with unexpected allocation id {}", id, allocationId);
} else {
logger.warn("trying to update status on non-existing task {}", id);
}
throw new ResourceNotFoundException("the task with id {} and allocation id {} doesn't exist", id, allocationId);
}
}

View File

@ -634,6 +634,17 @@ public final class PersistentTasksCustomMetaData extends AbstractNamedDiffable<M
return tasks.containsKey(taskId);
}
/**
 * Tells whether a task with the given id is present in the list and is
 * currently bound to the given allocation id.
 *
 * @param taskId       the id of the persistent task to look up
 * @param allocationId the allocation id the task is expected to have
 * @return {@code true} only if the task exists and its allocation id equals
 *         {@code allocationId}; {@code false} if the task is missing or the
 *         allocation ids differ
 */
public boolean hasTask(long taskId, long allocationId) {
    final PersistentTask<?> candidate = tasks.get(taskId);
    return candidate != null && candidate.getAllocationId() == allocationId;
}
/**
* Returns the id of the last added task
*/

View File

@ -101,27 +101,6 @@ public abstract class PersistentTasksExecutor<Request extends PersistentTaskRequ
}
/**
* Updates the persistent task status in the cluster state.
* <p>
* The status can be used to store the current progress of the task or provide an insight for the
* task allocator about the state of the currently running tasks.
*/
protected void updatePersistentTaskStatus(AllocatedPersistentTask task, Task.Status status, ActionListener<Empty> listener) {
persistentTasksService.updateStatus(task.getPersistentTaskId(), status,
new PersistentTaskOperationListener() {
@Override
public void onResponse(long taskId) {
listener.onResponse(Empty.INSTANCE);
}
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
});
}
/**
* This operation will be executed on the executor node.
* <p>

View File

@ -131,7 +131,7 @@ public class PersistentTasksNodeService extends AbstractComponent implements Clu
taskInProgress.getRequest());
boolean processed = false;
try {
task.init(persistentTasksService, taskInProgress.getId());
task.init(persistentTasksService, taskInProgress.getId(), taskInProgress.getAllocationId());
PersistentTaskListener listener = new PersistentTaskListener(task);
try {
runningTasks.put(new PersistentTaskId(taskInProgress.getId(), taskInProgress.getAllocationId()), task);

View File

@ -100,10 +100,14 @@ public class PersistentTasksService extends AbstractComponent {
}
/**
* Updates status of the persistent task
* Updates status of the persistent task.
*
* Persistent task implementers shouldn't call this method directly and should use
* {@link AllocatedPersistentTask#updatePersistentStatus} instead.
*/
public void updateStatus(long taskId, Task.Status status, PersistentTaskOperationListener listener) {
UpdatePersistentTaskStatusAction.Request updateStatusRequest = new UpdatePersistentTaskStatusAction.Request(taskId, status);
void updateStatus(long taskId, long allocationId, Task.Status status, PersistentTaskOperationListener listener) {
UpdatePersistentTaskStatusAction.Request updateStatusRequest =
new UpdatePersistentTaskStatusAction.Request(taskId, allocationId, status);
try {
client.execute(UpdatePersistentTaskStatusAction.INSTANCE, updateStatusRequest, ActionListener.wrap(
o -> listener.onResponse(taskId), listener::onFailure));

View File

@ -68,15 +68,16 @@ public class UpdatePersistentTaskStatusAction extends Action<UpdatePersistentTas
public static class Request extends MasterNodeRequest<Request> {
private long taskId;
private long allocationId;
private Task.Status status;
public Request() {
}
public Request(long taskId, Task.Status status) {
public Request(long taskId, long allocationId, Task.Status status) {
this.taskId = taskId;
this.allocationId = allocationId;
this.status = status;
}
@ -84,6 +85,10 @@ public class UpdatePersistentTaskStatusAction extends Action<UpdatePersistentTas
this.taskId = taskId;
}
/**
 * Sets the allocation id carried by this request.
 *
 * @param allocationId the allocation id to store on the request
 */
public void setAllocationId(long allocationId) {
this.allocationId = allocationId;
}
/**
 * Sets the task status carried by this request.
 *
 * @param status the status to store on the request
 */
public void setStatus(Task.Status status) {
this.status = status;
}
@ -92,6 +97,7 @@ public class UpdatePersistentTaskStatusAction extends Action<UpdatePersistentTas
public void readFrom(StreamInput in) throws IOException {
// Parent state is read first, then this request's own fields.
super.readFrom(in);
// Field order must mirror writeTo: taskId, allocationId, status.
taskId = in.readLong();
allocationId = in.readLong();
// Read with the "optional" variant to match writeOptionalNamedWriteable in writeTo.
status = in.readOptionalNamedWriteable(Task.Status.class);
}
@ -99,6 +105,7 @@ public class UpdatePersistentTaskStatusAction extends Action<UpdatePersistentTas
public void writeTo(StreamOutput out) throws IOException {
// Parent state is written first, then this request's own fields.
super.writeTo(out);
// Field order must mirror readFrom: taskId, allocationId, status.
out.writeLong(taskId);
out.writeLong(allocationId);
// Written with the "optional" variant to match readOptionalNamedWriteable in readFrom.
out.writeOptionalNamedWriteable(status);
}
@ -112,13 +119,13 @@ public class UpdatePersistentTaskStatusAction extends Action<UpdatePersistentTas
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return taskId == request.taskId &&
return taskId == request.taskId && allocationId == request.allocationId &&
Objects.equals(status, request.status);
}
@Override
public int hashCode() {
return Objects.hash(taskId, status);
return Objects.hash(taskId, allocationId, status);
}
}
@ -207,7 +214,8 @@ public class UpdatePersistentTaskStatusAction extends Action<UpdatePersistentTas
@Override
protected final void masterOperation(final Request request, ClusterState state, final ActionListener<Response> listener) {
persistentTasksClusterService.updatePersistentTaskStatus(request.taskId, request.status, new ActionListener<Empty>() {
persistentTasksClusterService.updatePersistentTaskStatus(request.taskId, request.allocationId, request.status,
new ActionListener<Empty>() {
@Override
public void onResponse(Empty empty) {
listener.onResponse(new Response(true));

View File

@ -19,14 +19,16 @@
package org.elasticsearch.persistent;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.BaseFuture;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.persistent.PersistentTasksService.WaitForPersistentTaskStatusListener;
import org.elasticsearch.persistent.TestPersistentTasksPlugin.Status;
import org.elasticsearch.persistent.TestPersistentTasksPlugin.TestPersistentTasksExecutor;
import org.elasticsearch.persistent.TestPersistentTasksPlugin.TestRequest;
import org.elasticsearch.persistent.TestPersistentTasksPlugin.TestTasksRequestBuilder;
@ -37,6 +39,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThrows;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
@ -63,17 +66,11 @@ public class PersistentTasksExecutorIT extends ESIntegTestCase {
assertNoRunningTasks();
}
public static class PersistentTaskOperationFuture extends BaseFuture<Long> implements WaitForPersistentTaskStatusListener {
public static class PersistentTaskOperationFuture extends PlainActionFuture<Long> implements WaitForPersistentTaskStatusListener {
@Override
public void onResponse(long taskId) {
set(taskId);
}
@Override
public void onFailure(Exception e) {
setException(e);
}
}
public void testPersistentActionFailure() throws Exception {
@ -200,7 +197,12 @@ public class PersistentTasksExecutorIT extends ESIntegTestCase {
persistentTasksService.waitForPersistentTaskStatus(taskId,
task -> false, TimeValue.timeValueMillis(10), future1);
expectThrows(Exception.class, future1::get);
assertThrows(future1, IllegalStateException.class, "timed out after 10ms");
PersistentTaskOperationFuture failedUpdateFuture = new PersistentTaskOperationFuture();
persistentTasksService.updateStatus(taskId, -1, new Status("should fail"), failedUpdateFuture);
assertThrows(failedUpdateFuture, ResourceNotFoundException.class, "the task with id " + taskId +
" and allocation id -1 doesn't exist");
// Wait for the task to disappear
PersistentTaskOperationFuture future2 = new PersistentTaskOperationFuture();

View File

@ -63,6 +63,7 @@ import org.elasticsearch.transport.TransportResponse.Empty;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData.Assignment;
import org.elasticsearch.persistent.PersistentTasksService.PersistentTaskOperationListener;
import java.io.IOException;
import java.util.ArrayList;
@ -365,9 +366,9 @@ public class TestPersistentTasksPlugin extends Plugin implements ActionPlugin {
CountDownLatch latch = new CountDownLatch(1);
Status status = new Status("phase " + phase.incrementAndGet());
logger.info("updating the task status to {}", status);
updatePersistentTaskStatus(task, status, new ActionListener<Empty>() {
task.updatePersistentStatus(status, new PersistentTaskOperationListener() {
@Override
public void onResponse(Empty empty) {
public void onResponse(long taskId) {
logger.info("updating was successful");
latch.countDown();
}

View File

@ -30,7 +30,7 @@ public class UpdatePersistentTaskRequestTests extends AbstractStreamableTestCase
@Override
protected Request createTestInstance() {
return new Request(randomLong(), new Status(randomAsciiOfLength(10)));
return new Request(randomLong(), randomLong(), new Status(randomAsciiOfLength(10)));
}
@Override