Verify AllocationIDs in replication actions (#20320)

Replicated operation consist of a routing action (the original), which is in charge of sending the operation to the primary shard, a primary action which executes the operation on the resolved primary and replica actions which performs the operation on a specific replica. This commit adds the targeted shard's allocation id to the primary and replica actions and makes sure that those match the shard the actions end up executing on.

This helps preventing extremely rare failure mode where a shard moves off a node and back to it, all between an action is sent and the time it's processed. 

For example:
1) Primary action is sent to a relocating primary on node A.
2) The primary finishes relocation to node B and start relocating back.
3) The relocation back gets to the phase and opens up the target engine, on the original node, node A.
4) The primary action is executed on the target engine before the relocation finishes, at which the shard copy on node B is still the official primary - i.e., it is executed on the wrong primary.
This commit is contained in:
Boaz Leskes 2016-09-06 14:32:48 +02:00 committed by GitHub
parent b6bf20c2da
commit c56cd46162
5 changed files with 377 additions and 86 deletions

View File

@ -38,10 +38,12 @@ import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.routing.AllocationId;
import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.settings.Settings;
@ -53,14 +55,17 @@ import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.IndexShardState;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardNotFoundException;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportChannelResponseHandler;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
@ -69,6 +74,7 @@ import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Supplier;
@ -115,9 +121,12 @@ public abstract class TransportReplicationAction<
this.transportPrimaryAction = actionName + "[p]";
this.transportReplicaAction = actionName + "[r]";
transportService.registerRequestHandler(actionName, request, ThreadPool.Names.SAME, new OperationTransportHandler());
transportService.registerRequestHandler(transportPrimaryAction, request, executor, new PrimaryOperationTransportHandler());
transportService.registerRequestHandler(transportPrimaryAction, () -> new ConcreteShardRequest<>(request), executor,
new PrimaryOperationTransportHandler());
// we must never reject on because of thread pool capacity on replicas
transportService.registerRequestHandler(transportReplicaAction, replicaRequest, executor, true, true,
transportService.registerRequestHandler(transportReplicaAction,
() -> new ConcreteShardRequest<>(replicaRequest),
executor, true, true,
new ReplicaOperationTransportHandler());
this.transportOptions = transportOptions();
@ -163,7 +172,7 @@ public abstract class TransportReplicationAction<
/**
* Synchronous replica operation on nodes with replica copies. This is done under the lock form
* {@link #acquireReplicaOperationLock(ShardId, long, ActionListener)}.
* {@link #acquireReplicaOperationLock(ShardId, long, String, ActionListener)}.
*/
protected abstract ReplicaResult shardOperationOnReplica(ReplicaRequest shardRequest);
@ -230,33 +239,36 @@ public abstract class TransportReplicationAction<
}
}
class PrimaryOperationTransportHandler implements TransportRequestHandler<Request> {
class PrimaryOperationTransportHandler implements TransportRequestHandler<ConcreteShardRequest<Request>> {
@Override
public void messageReceived(final Request request, final TransportChannel channel) throws Exception {
public void messageReceived(final ConcreteShardRequest<Request> request, final TransportChannel channel) throws Exception {
throw new UnsupportedOperationException("the task parameter is required for this operation");
}
@Override
public void messageReceived(Request request, TransportChannel channel, Task task) {
new AsyncPrimaryAction(request, channel, (ReplicationTask) task).run();
public void messageReceived(ConcreteShardRequest<Request> request, TransportChannel channel, Task task) {
new AsyncPrimaryAction(request.request, request.targetAllocationID, channel, (ReplicationTask) task).run();
}
}
class AsyncPrimaryAction extends AbstractRunnable implements ActionListener<PrimaryShardReference> {
private final Request request;
/** targetAllocationID of the shard this request is meant for */
private final String targetAllocationID;
private final TransportChannel channel;
private final ReplicationTask replicationTask;
AsyncPrimaryAction(Request request, TransportChannel channel, ReplicationTask replicationTask) {
AsyncPrimaryAction(Request request, String targetAllocationID, TransportChannel channel, ReplicationTask replicationTask) {
this.request = request;
this.targetAllocationID = targetAllocationID;
this.channel = channel;
this.replicationTask = replicationTask;
}
@Override
protected void doRun() throws Exception {
acquirePrimaryShardReference(request.shardId(), this);
acquirePrimaryShardReference(request.shardId(), targetAllocationID, this);
}
@Override
@ -271,7 +283,9 @@ public abstract class TransportReplicationAction<
final ShardRouting primary = primaryShardReference.routingEntry();
assert primary.relocating() : "indexShard is marked as relocated but routing isn't" + primary;
DiscoveryNode relocatingNode = clusterService.state().nodes().get(primary.relocatingNodeId());
transportService.sendRequest(relocatingNode, transportPrimaryAction, request, transportOptions,
transportService.sendRequest(relocatingNode, transportPrimaryAction,
new ConcreteShardRequest<>(request, primary.allocationId().getRelocationId()),
transportOptions,
new TransportChannelResponseHandler<Response>(logger, channel, "rerouting indexing to target primary " + primary,
TransportReplicationAction.this::newResponseInstance) {
@ -391,15 +405,17 @@ public abstract class TransportReplicationAction<
}
}
class ReplicaOperationTransportHandler implements TransportRequestHandler<ReplicaRequest> {
class ReplicaOperationTransportHandler implements TransportRequestHandler<ConcreteShardRequest<ReplicaRequest>> {
@Override
public void messageReceived(final ReplicaRequest request, final TransportChannel channel) throws Exception {
public void messageReceived(final ConcreteShardRequest<ReplicaRequest> request, final TransportChannel channel)
throws Exception {
throw new UnsupportedOperationException("the task parameter is required for this operation");
}
@Override
public void messageReceived(ReplicaRequest request, TransportChannel channel, Task task) throws Exception {
new AsyncReplicaAction(request, channel, (ReplicationTask) task).run();
public void messageReceived(ConcreteShardRequest<ReplicaRequest> requestWithAID, TransportChannel channel, Task task)
throws Exception {
new AsyncReplicaAction(requestWithAID.request, requestWithAID.targetAllocationID, channel, (ReplicationTask) task).run();
}
}
@ -417,6 +433,8 @@ public abstract class TransportReplicationAction<
private final class AsyncReplicaAction extends AbstractRunnable implements ActionListener<Releasable> {
private final ReplicaRequest request;
// allocation id of the replica this request is meant for
private final String targetAllocationID;
private final TransportChannel channel;
/**
* The task on the node with the replica shard.
@ -426,10 +444,11 @@ public abstract class TransportReplicationAction<
// something we want to avoid at all costs
private final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext());
AsyncReplicaAction(ReplicaRequest request, TransportChannel channel, ReplicationTask task) {
AsyncReplicaAction(ReplicaRequest request, String targetAllocationID, TransportChannel channel, ReplicationTask task) {
this.request = request;
this.channel = channel;
this.task = task;
this.targetAllocationID = targetAllocationID;
}
@Override
@ -464,7 +483,9 @@ public abstract class TransportReplicationAction<
String extraMessage = "action [" + transportReplicaAction + "], request[" + request + "]";
TransportChannelResponseHandler<TransportResponse.Empty> handler =
new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE);
transportService.sendRequest(clusterService.localNode(), transportReplicaAction, request, handler);
transportService.sendRequest(clusterService.localNode(), transportReplicaAction,
new ConcreteShardRequest<>(request, targetAllocationID),
handler);
}
@Override
@ -501,7 +522,7 @@ public abstract class TransportReplicationAction<
protected void doRun() throws Exception {
setPhase(task, "replica");
assert request.shardId() != null : "request shardId must be set";
acquireReplicaOperationLock(request.shardId(), request.primaryTerm(), this);
acquireReplicaOperationLock(request.shardId(), request.primaryTerm(), targetAllocationID, this);
}
/**
@ -598,7 +619,7 @@ public abstract class TransportReplicationAction<
logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}] ",
transportPrimaryAction, request.shardId(), request, state.version(), primary.currentNodeId());
}
performAction(node, transportPrimaryAction, true);
performAction(node, transportPrimaryAction, true, new ConcreteShardRequest<>(request, primary.allocationId().getId()));
}
private void performRemoteAction(ClusterState state, ShardRouting primary, DiscoveryNode node) {
@ -620,7 +641,7 @@ public abstract class TransportReplicationAction<
request.shardId(), request, state.version(), primary.currentNodeId());
}
setPhase(task, "rerouted");
performAction(node, actionName, false);
performAction(node, actionName, false, request);
}
private boolean retryIfUnavailable(ClusterState state, ShardRouting primary) {
@ -671,8 +692,9 @@ public abstract class TransportReplicationAction<
}
}
private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction) {
transportService.sendRequest(node, action, request, transportOptions, new TransportResponseHandler<Response>() {
private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction,
final TransportRequest requestToPerform) {
transportService.sendRequest(node, action, requestToPerform, transportOptions, new TransportResponseHandler<Response>() {
@Override
public Response newInstance() {
@ -700,7 +722,7 @@ public abstract class TransportReplicationAction<
(org.apache.logging.log4j.util.Supplier<?>) () -> new ParameterizedMessage(
"received an error from node [{}] for request [{}], scheduling a retry",
node.getId(),
request),
requestToPerform),
exp);
retry(exp);
} else {
@ -794,7 +816,8 @@ public abstract class TransportReplicationAction<
* tries to acquire reference to {@link IndexShard} to perform a primary operation. Released after performing primary operation locally
* and replication of the operation to all replica shards is completed / failed (see {@link ReplicationOperation}).
*/
protected void acquirePrimaryShardReference(ShardId shardId, ActionListener<PrimaryShardReference> onReferenceAcquired) {
protected void acquirePrimaryShardReference(ShardId shardId, String allocationId,
ActionListener<PrimaryShardReference> onReferenceAcquired) {
IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex());
IndexShard indexShard = indexService.getShard(shardId.id());
// we may end up here if the cluster state used to route the primary is so stale that the underlying
@ -804,6 +827,10 @@ public abstract class TransportReplicationAction<
throw new ReplicationOperation.RetryOnPrimaryException(indexShard.shardId(),
"actual shard is not a primary " + indexShard.routingEntry());
}
final String actualAllocationId = indexShard.routingEntry().allocationId().getId();
if (actualAllocationId.equals(allocationId) == false) {
throw new ShardNotFoundException(shardId, "expected aID [{}] but found [{}]", allocationId, actualAllocationId);
}
ActionListener<Releasable> onAcquired = new ActionListener<Releasable>() {
@Override
@ -823,9 +850,14 @@ public abstract class TransportReplicationAction<
/**
* tries to acquire an operation on replicas. The lock is closed as soon as replication is completed on the node.
*/
protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, ActionListener<Releasable> onLockAcquired) {
protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, final String allocationId,
ActionListener<Releasable> onLockAcquired) {
IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex());
IndexShard indexShard = indexService.getShard(shardId.id());
final String actualAllocationId = indexShard.routingEntry().allocationId().getId();
if (actualAllocationId.equals(allocationId) == false) {
throw new ShardNotFoundException(shardId, "expected aID [{}] but found [{}]", allocationId, actualAllocationId);
}
indexShard.acquireReplicaOperationLock(primaryTerm, onLockAcquired, executor);
}
@ -888,7 +920,8 @@ public abstract class TransportReplicationAction<
listener.onFailure(new NoNodeAvailableException("unknown node [" + nodeId + "]"));
return;
}
transportService.sendRequest(node, transportReplicaAction, request, transportOptions,
transportService.sendRequest(node, transportReplicaAction,
new ConcreteShardRequest<>(request, replica.allocationId().getId()), transportOptions,
new ActionListenerResponseHandler<>(listener, () -> TransportResponse.Empty.INSTANCE));
}
@ -930,6 +963,72 @@ public abstract class TransportReplicationAction<
}
}
/** a wrapper class to encapsulate a request when being sent to a specific allocation id **/
final class ConcreteShardRequest<R extends TransportRequest> extends TransportRequest {
/** {@link AllocationId#getId()} of the shard this request is sent to **/
private String targetAllocationID;
private R request;
ConcreteShardRequest(Supplier<R> requestSupplier) {
request = requestSupplier.get();
// null now, but will be populated by reading from the streams
targetAllocationID = null;
}
ConcreteShardRequest(R request, String targetAllocationID) {
Objects.requireNonNull(request);
Objects.requireNonNull(targetAllocationID);
this.request = request;
this.targetAllocationID = targetAllocationID;
}
@Override
public void setParentTask(String parentTaskNode, long parentTaskId) {
request.setParentTask(parentTaskNode, parentTaskId);
}
@Override
public void setParentTask(TaskId taskId) {
request.setParentTask(taskId);
}
@Override
public TaskId getParentTask() {
return request.getParentTask();
}
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId) {
return request.createTask(id, type, action, parentTaskId);
}
@Override
public String getDescription() {
return "[" + request.getDescription() + "] for aID [" + targetAllocationID + "]";
}
@Override
public void readFrom(StreamInput in) throws IOException {
targetAllocationID = in.readString();
request.readFrom(in);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(targetAllocationID);
request.writeTo(out);
}
public R getRequest() {
return request;
}
public String getTargetAllocationID() {
return targetAllocationID;
}
}
/**
* Sets the current phase on the task if it isn't null. Pulled into its own
* method because its more convenient that way.

View File

@ -27,8 +27,8 @@ import org.elasticsearch.cluster.Diffable;
import org.elasticsearch.cluster.DiffableUtils;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -88,6 +88,11 @@ public class RoutingTable implements Iterable<IndexRoutingTable>, Diffable<Routi
return indicesRouting.containsKey(index);
}
public boolean hasIndex(Index index) {
IndexRoutingTable indexRouting = index(index.getName());
return indexRouting != null && indexRouting.getIndex().equals(index);
}
public IndexRoutingTable index(String index) {
return indicesRouting.get(index);
}

View File

@ -33,10 +33,18 @@ public class ShardNotFoundException extends ResourceNotFoundException {
}
public ShardNotFoundException(ShardId shardId, Throwable ex) {
super("no such shard", ex);
setShard(shardId);
this(shardId, "no such shard", ex);
}
public ShardNotFoundException(ShardId shardId, String msg, Object... args) {
this(shardId, msg, null, args);
}
public ShardNotFoundException(ShardId shardId, String msg, Throwable ex, Object... args) {
super(msg, ex, args);
setShard(shardId);
}
public ShardNotFoundException(StreamInput in) throws IOException{
super(in);
}

View File

@ -69,6 +69,7 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.action.support.replication.TransportReplicationActionTests;
import org.elasticsearch.action.termvectors.MultiTermVectorsAction;
import org.elasticsearch.action.termvectors.MultiTermVectorsRequest;
import org.elasticsearch.action.termvectors.TermVectorsAction;
@ -117,7 +118,6 @@ import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
@ClusterScope(scope = Scope.SUITE, numClientNodes = 1, minNumDataNodes = 2)
public class IndicesRequestIT extends ESIntegTestCase {
@ -638,8 +638,7 @@ public class IndicesRequestIT extends ESIntegTestCase {
assertThat("no internal requests intercepted for action [" + action + "]", requests.size(), greaterThan(0));
}
for (TransportRequest internalRequest : requests) {
assertThat(internalRequest, instanceOf(IndicesRequest.class));
IndicesRequest indicesRequest = (IndicesRequest) internalRequest;
IndicesRequest indicesRequest = convertRequest(internalRequest);
assertThat(internalRequest.getClass().getName(), indicesRequest.indices(), equalTo(originalRequest.indices()));
assertThat(indicesRequest.indicesOptions(), equalTo(originalRequest.indicesOptions()));
}
@ -651,14 +650,24 @@ public class IndicesRequestIT extends ESIntegTestCase {
List<TransportRequest> requests = consumeTransportRequests(action);
assertThat("no internal requests intercepted for action [" + action + "]", requests.size(), greaterThan(0));
for (TransportRequest internalRequest : requests) {
assertThat(internalRequest, instanceOf(IndicesRequest.class));
for (String index : ((IndicesRequest) internalRequest).indices()) {
IndicesRequest indicesRequest = convertRequest(internalRequest);
for (String index : indicesRequest.indices()) {
assertThat(indices, hasItem(index));
}
}
}
}
static IndicesRequest convertRequest(TransportRequest request) {
final IndicesRequest indicesRequest;
if (request instanceof IndicesRequest) {
indicesRequest = (IndicesRequest) request;
} else {
indicesRequest = TransportReplicationActionTests.resolveRequest(request);
}
return indicesRequest;
}
private String randomIndexOrAlias() {
String index = randomFrom(indices);
if (randomBoolean()) {

View File

@ -26,6 +26,7 @@ import org.elasticsearch.action.support.ActiveShardCount;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.transport.NoNodeAvailableException;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ESAllocationTestCase;
import org.elasticsearch.cluster.action.shard.ShardStateAction;
import org.elasticsearch.cluster.block.ClusterBlock;
import org.elasticsearch.cluster.block.ClusterBlockException;
@ -36,6 +37,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
import org.elasticsearch.cluster.routing.RoutingNode;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.routing.TestShardRouting;
@ -47,21 +49,25 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.engine.EngineClosedException;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.IndexShardClosedException;
import org.elasticsearch.index.shard.IndexShardState;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardNotFoundException;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.cluster.ESAllocationTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseOptions;
import org.elasticsearch.transport.TransportService;
@ -75,12 +81,12 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state;
@ -93,12 +99,32 @@ import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class TransportReplicationActionTests extends ESTestCase {
/**
* takes a request that was sent by a {@link TransportReplicationAction} and captured
* and returns the underlying request if it's wrapped or the original (cast to the expected type).
*
* This will throw a {@link ClassCastException} if the request is of the wrong type.
*/
public static <R extends ReplicationRequest> R resolveRequest(TransportRequest requestOrWrappedRequest) {
if (requestOrWrappedRequest instanceof TransportReplicationAction.ConcreteShardRequest) {
requestOrWrappedRequest = ((TransportReplicationAction.ConcreteShardRequest)requestOrWrappedRequest).getRequest();
}
return (R) requestOrWrappedRequest;
}
private static ThreadPool threadPool;
private ClusterService clusterService;
@ -411,7 +437,7 @@ public class TransportReplicationActionTests extends ESTestCase {
isRelocated.set(true);
executeOnPrimary = false;
}
action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) {
action.new AsyncPrimaryAction(request, primaryShard.allocationId().getId(), createTransportChannel(listener), task) {
@Override
protected ReplicationOperation<Request, Request, Action.PrimaryResult> createReplicatedOperation(Request request,
ActionListener<Action.PrimaryResult> actionListener, Action.PrimaryShardReference primaryShardReference,
@ -452,7 +478,8 @@ public class TransportReplicationActionTests extends ESTestCase {
final String index = "test";
final ShardId shardId = new ShardId(index, "_na_", 0);
ClusterState state = state(index, true, ShardRoutingState.RELOCATING);
String primaryTargetNodeId = state.getRoutingTable().shardRoutingTable(shardId).primaryShard().relocatingNodeId();
final ShardRouting primaryShard = state.getRoutingTable().shardRoutingTable(shardId).primaryShard();
String primaryTargetNodeId = primaryShard.relocatingNodeId();
// simulate execution of the primary phase on the relocation target node
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(primaryTargetNodeId)).build();
setState(clusterService, state);
@ -460,7 +487,7 @@ public class TransportReplicationActionTests extends ESTestCase {
PlainActionFuture<Response> listener = new PlainActionFuture<>();
ReplicationTask task = maybeTask();
AtomicBoolean executed = new AtomicBoolean();
action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) {
action.new AsyncPrimaryAction(request, primaryShard.allocationId().getRelocationId(), createTransportChannel(listener), task) {
@Override
protected ReplicationOperation<Request, Request, Action.PrimaryResult> createReplicatedOperation(Request request,
ActionListener<Action.PrimaryResult> actionListener, Action.PrimaryShardReference primaryShardReference,
@ -473,6 +500,11 @@ public class TransportReplicationActionTests extends ESTestCase {
}
};
}
@Override
public void onFailure(Exception e) {
throw new RuntimeException(e);
}
}.run();
assertThat(executed.get(), equalTo(true));
assertPhase(task, "finished");
@ -596,7 +628,9 @@ public class TransportReplicationActionTests extends ESTestCase {
state = ClusterState.builder(state).metaData(metaData).build();
setState(clusterService, state);
AtomicBoolean executed = new AtomicBoolean();
action.new AsyncPrimaryAction(new Request(shardId), createTransportChannel(new PlainActionFuture<>()), null) {
ShardRouting primaryShard = state.routingTable().shardRoutingTable(shardId).primaryShard();
action.new AsyncPrimaryAction(new Request(shardId), primaryShard.allocationId().getId(),
createTransportChannel(new PlainActionFuture<>()), null) {
@Override
protected ReplicationOperation<Request, Request, Action.PrimaryResult> createReplicatedOperation(Request request,
ActionListener<Action.PrimaryResult> actionListener, Action.PrimaryShardReference primaryShardReference,
@ -613,8 +647,10 @@ public class TransportReplicationActionTests extends ESTestCase {
final String index = "test";
final ShardId shardId = new ShardId(index, "_na_", 0);
// no replica, we only want to test on primary
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
final ClusterState state = state(index, true, ShardRoutingState.STARTED);
setState(clusterService, state);
logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint());
final ShardRouting primaryShard = state.routingTable().shardRoutingTable(shardId).primaryShard();
Request request = new Request(shardId);
PlainActionFuture<Response> listener = new PlainActionFuture<>();
ReplicationTask task = maybeTask();
@ -622,7 +658,7 @@ public class TransportReplicationActionTests extends ESTestCase {
final boolean throwExceptionOnCreation = i == 1;
final boolean throwExceptionOnRun = i == 2;
final boolean respondWithError = i == 3;
action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) {
action.new AsyncPrimaryAction(request, primaryShard.allocationId().getId(), createTransportChannel(listener), task) {
@Override
protected ReplicationOperation<Request, Request, Action.PrimaryResult> createReplicatedOperation(Request request,
ActionListener<Action.PrimaryResult> actionListener, Action.PrimaryShardReference primaryShardReference,
@ -666,8 +702,9 @@ public class TransportReplicationActionTests extends ESTestCase {
public void testReplicasCounter() throws Exception {
final ShardId shardId = new ShardId("test", "_na_", 0);
setState(clusterService, state(shardId.getIndexName(), true,
ShardRoutingState.STARTED, ShardRoutingState.STARTED));
final ClusterState state = state(shardId.getIndexName(), true, ShardRoutingState.STARTED, ShardRoutingState.STARTED);
setState(clusterService, state);
final ShardRouting replicaRouting = state.getRoutingTable().shardRoutingTable(shardId).replicaShards().get(0);
boolean throwException = randomBoolean();
final ReplicationTask task = maybeTask();
Action action = new Action(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool) {
@ -683,7 +720,8 @@ public class TransportReplicationActionTests extends ESTestCase {
};
final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler();
try {
replicaOperationTransportHandler.messageReceived(new Request().setShardId(shardId),
replicaOperationTransportHandler.messageReceived(
action.new ConcreteShardRequest(new Request().setShardId(shardId), replicaRouting.allocationId().getId()),
createTransportChannel(new PlainActionFuture<>()), task);
} catch (ElasticsearchException e) {
assertThat(e.getMessage(), containsString("simulated"));
@ -725,6 +763,111 @@ public class TransportReplicationActionTests extends ESTestCase {
assertEquals(ActiveShardCount.from(requestWaitForActiveShards), request.waitForActiveShards());
}
/** test that a primary request is rejected if it arrives at a shard with a wrong allocation id */
public void testPrimaryActionRejectsWrongAid() throws Exception {
final String index = "test";
final ShardId shardId = new ShardId(index, "_na_", 0);
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
PlainActionFuture<Response> listener = new PlainActionFuture<>();
Request request = new Request(shardId).timeout("1ms");
action.new PrimaryOperationTransportHandler().messageReceived(
action.new ConcreteShardRequest(request, "_not_a_valid_aid_"),
createTransportChannel(listener), maybeTask()
);
try {
listener.get();
fail("using a wrong aid didn't fail the operation");
} catch (ExecutionException execException) {
Throwable throwable = execException.getCause();
logger.debug("got exception:" , throwable);
assertTrue(throwable.getClass() + " is not a retry exception", action.retryPrimaryException(throwable));
}
}
/** test that a replica request is rejected if it arrives at a shard with a wrong allocation id */
public void testReplicaActionRejectsWrongAid() throws Exception {
final String index = "test";
final ShardId shardId = new ShardId(index, "_na_", 0);
ClusterState state = state(index, false, ShardRoutingState.STARTED, ShardRoutingState.STARTED);
final ShardRouting replica = state.routingTable().shardRoutingTable(shardId).replicaShards().get(0);
// simulate execution of the node holding the replica
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
setState(clusterService, state);
PlainActionFuture<Response> listener = new PlainActionFuture<>();
Request request = new Request(shardId).timeout("1ms");
action.new ReplicaOperationTransportHandler().messageReceived(
action.new ConcreteShardRequest(request, "_not_a_valid_aid_"),
createTransportChannel(listener), maybeTask()
);
try {
listener.get();
fail("using a wrong aid didn't fail the operation");
} catch (ExecutionException execException) {
Throwable throwable = execException.getCause();
if (action.retryPrimaryException(throwable) == false) {
throw new AssertionError("thrown exception is not retriable", throwable);
}
assertThat(throwable.getMessage(), containsString("_not_a_valid_aid_"));
}
}
/**
* test throwing a {@link org.elasticsearch.action.support.replication.TransportReplicationAction.RetryOnReplicaException}
* causes a retry
*/
public void testRetryOnReplica() throws Exception {
final ShardId shardId = new ShardId("test", "_na_", 0);
ClusterState state = state(shardId.getIndexName(), true, ShardRoutingState.STARTED, ShardRoutingState.STARTED);
final ShardRouting replica = state.getRoutingTable().shardRoutingTable(shardId).replicaShards().get(0);
// simulate execution of the node holding the replica
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
setState(clusterService, state);
AtomicBoolean throwException = new AtomicBoolean(true);
final ReplicationTask task = maybeTask();
Action action = new Action(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool) {
@Override
protected ReplicaResult shardOperationOnReplica(Request request) {
assertPhase(task, "replica");
if (throwException.get()) {
throw new RetryOnReplicaException(shardId, "simulation");
}
return new ReplicaResult();
}
};
final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler();
final PlainActionFuture<Response> listener = new PlainActionFuture<>();
final Request request = new Request().setShardId(shardId);
request.primaryTerm(state.metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id()));
replicaOperationTransportHandler.messageReceived(
action.new ConcreteShardRequest(request, replica.allocationId().getId()),
createTransportChannel(listener), task);
if (listener.isDone()) {
listener.get(); // fail with the exception if there
fail("listener shouldn't be done");
}
// no retry yet
List<CapturingTransport.CapturedRequest> capturedRequests =
transport.getCapturedRequestsByTargetNodeAndClear().get(replica.currentNodeId());
assertThat(capturedRequests, nullValue());
// release the waiting
throwException.set(false);
setState(clusterService, state);
capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear().get(replica.currentNodeId());
assertThat(capturedRequests, notNullValue());
assertThat(capturedRequests.size(), equalTo(1));
final CapturingTransport.CapturedRequest capturedRequest = capturedRequests.get(0);
assertThat(capturedRequest.action, equalTo("testActionWithExceptions[r]"));
assertThat(capturedRequest.request, instanceOf(TransportReplicationAction.ConcreteShardRequest.class));
assertThat(((TransportReplicationAction.ConcreteShardRequest) capturedRequest.request).getRequest(), equalTo(request));
assertThat(((TransportReplicationAction.ConcreteShardRequest) capturedRequest.request).getTargetAllocationID(),
equalTo(replica.allocationId().getId()));
}
private void assertIndexShardCounter(int expected) {
assertThat(count.get(), equalTo(expected));
}
@ -797,7 +940,7 @@ public class TransportReplicationActionTests extends ESTestCase {
Action(Settings settings, String actionName, TransportService transportService,
ClusterService clusterService,
ThreadPool threadPool) {
super(settings, actionName, transportService, clusterService, null, threadPool,
super(settings, actionName, transportService, clusterService, mockIndicesService(clusterService), threadPool,
new ShardStateAction(settings, clusterService, transportService, null, null, threadPool),
new ActionFilters(new HashSet<>()), new IndexNameExpressionResolver(Settings.EMPTY),
Request::new, Request::new, ThreadPool.Names.SAME);
@ -825,43 +968,76 @@ public class TransportReplicationActionTests extends ESTestCase {
protected boolean resolveIndex() {
return false;
}
}
@Override
protected void acquirePrimaryShardReference(ShardId shardId, ActionListener<PrimaryShardReference> onReferenceAcquired) {
final IndicesService mockIndicesService(ClusterService clusterService) {
final IndicesService indicesService = mock(IndicesService.class);
when(indicesService.indexServiceSafe(any(Index.class))).then(invocation -> {
Index index = (Index)invocation.getArguments()[0];
final ClusterState state = clusterService.state();
final IndexMetaData indexSafe = state.metaData().getIndexSafe(index);
return mockIndexService(indexSafe, clusterService);
});
when(indicesService.indexService(any(Index.class))).then(invocation -> {
Index index = (Index) invocation.getArguments()[0];
final ClusterState state = clusterService.state();
if (state.metaData().hasIndex(index.getName())) {
final IndexMetaData indexSafe = state.metaData().getIndexSafe(index);
return mockIndexService(clusterService.state().metaData().getIndexSafe(index), clusterService);
} else {
return null;
}
});
return indicesService;
}
final IndexService mockIndexService(final IndexMetaData indexMetaData, ClusterService clusterService) {
final IndexService indexService = mock(IndexService.class);
when(indexService.getShard(anyInt())).then(invocation -> {
int shard = (Integer) invocation.getArguments()[0];
final ShardId shardId = new ShardId(indexMetaData.getIndex(), shard);
if (shard > indexMetaData.getNumberOfShards()) {
throw new ShardNotFoundException(shardId);
}
return mockIndexShard(shardId, clusterService);
});
return indexService;
}
private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService) {
final IndexShard indexShard = mock(IndexShard.class);
doAnswer(invocation -> {
ActionListener<Releasable> callback = (ActionListener<Releasable>) invocation.getArguments()[0];
count.incrementAndGet();
PrimaryShardReference primaryShardReference = new PrimaryShardReference(null, null) {
@Override
public boolean isRelocated() {
return isRelocated.get();
}
@Override
public void failShard(String reason, @Nullable Exception e) {
throw new UnsupportedOperationException();
}
@Override
public ShardRouting routingEntry() {
ShardRouting shardRouting = clusterService.state().getRoutingTable()
.shardRoutingTable(shardId).primaryShard();
assert shardRouting != null;
return shardRouting;
}
@Override
public void close() {
count.decrementAndGet();
}
};
onReferenceAcquired.onResponse(primaryShardReference);
}
@Override
protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, ActionListener<Releasable> onLockAcquired) {
callback.onResponse(count::decrementAndGet);
return null;
}).when(indexShard).acquirePrimaryOperationLock(any(ActionListener.class), anyString());
doAnswer(invocation -> {
long term = (Long)invocation.getArguments()[0];
ActionListener<Releasable> callback = (ActionListener<Releasable>) invocation.getArguments()[1];
final long primaryTerm = indexShard.getPrimaryTerm();
if (term < primaryTerm) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s operation term [%d] is too old (current [%d])",
shardId, term, primaryTerm));
}
count.incrementAndGet();
onLockAcquired.onResponse(count::decrementAndGet);
}
callback.onResponse(count::decrementAndGet);
return null;
}).when(indexShard).acquireReplicaOperationLock(anyLong(), any(ActionListener.class), anyString());
when(indexShard.routingEntry()).thenAnswer(invocationOnMock -> {
final ClusterState state = clusterService.state();
final RoutingNode node = state.getRoutingNodes().node(state.nodes().getLocalNodeId());
final ShardRouting routing = node.getByShardId(shardId);
if (routing == null) {
throw new ShardNotFoundException(shardId, "shard is no longer assigned to current node");
}
return routing;
});
when(indexShard.state()).thenAnswer(invocationOnMock -> isRelocated.get() ? IndexShardState.RELOCATED : IndexShardState.STARTED);
doThrow(new AssertionError("failed shard is not supported")).when(indexShard).failShard(anyString(), any(Exception.class));
when(indexShard.getPrimaryTerm()).thenAnswer(i ->
clusterService.state().metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id()));
return indexShard;
}
class NoopReplicationOperation extends ReplicationOperation<Request, Request, Action.PrimaryResult> {
@ -879,11 +1055,6 @@ public class TransportReplicationActionTests extends ESTestCase {
* Transport channel that is needed for replica operation testing.
*/
public TransportChannel createTransportChannel(final PlainActionFuture<Response> listener) {
return createTransportChannel(listener, error -> {
});
}
public TransportChannel createTransportChannel(final PlainActionFuture<Response> listener, Consumer<Throwable> consumer) {
return new TransportChannel() {
@Override
@ -908,7 +1079,6 @@ public class TransportReplicationActionTests extends ESTestCase {
@Override
public void sendResponse(Exception exception) throws IOException {
consumer.accept(exception);
listener.onFailure(exception);
}