diff --git a/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java index 883d21154bd..0fa027744ac 100644 --- a/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java @@ -38,10 +38,12 @@ import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.routing.AllocationId; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.settings.Settings; @@ -53,14 +55,17 @@ import org.elasticsearch.index.IndexService; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardNotFoundException; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportChannelResponseHandler; import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; @@ -69,6 +74,7 @@ import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import java.io.IOException; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; @@ -115,9 +121,12 @@ public abstract class TransportReplicationAction< this.transportPrimaryAction = actionName + "[p]"; this.transportReplicaAction = actionName + "[r]"; transportService.registerRequestHandler(actionName, request, ThreadPool.Names.SAME, new OperationTransportHandler()); - transportService.registerRequestHandler(transportPrimaryAction, request, executor, new PrimaryOperationTransportHandler()); + transportService.registerRequestHandler(transportPrimaryAction, () -> new ConcreteShardRequest<>(request), executor, + new PrimaryOperationTransportHandler()); // we must never reject on because of thread pool capacity on replicas - transportService.registerRequestHandler(transportReplicaAction, replicaRequest, executor, true, true, + transportService.registerRequestHandler(transportReplicaAction, + () -> new ConcreteShardRequest<>(replicaRequest), + executor, true, true, new ReplicaOperationTransportHandler()); this.transportOptions = transportOptions(); @@ -163,7 +172,7 @@ public abstract class TransportReplicationAction< /** * Synchronous replica operation on nodes with replica copies. This is done under the lock form - * {@link #acquireReplicaOperationLock(ShardId, long, ActionListener)}. + * {@link #acquireReplicaOperationLock(ShardId, long, String, ActionListener)}. */ protected abstract ReplicaResult shardOperationOnReplica(ReplicaRequest shardRequest); @@ -230,33 +239,36 @@ public abstract class TransportReplicationAction< } } - class PrimaryOperationTransportHandler implements TransportRequestHandler { + class PrimaryOperationTransportHandler implements TransportRequestHandler> { @Override - public void messageReceived(final Request request, final TransportChannel channel) throws Exception { + public void messageReceived(final ConcreteShardRequest request, final TransportChannel channel) throws Exception { throw new UnsupportedOperationException("the task parameter is required for this operation"); } @Override - public void messageReceived(Request request, TransportChannel channel, Task task) { - new AsyncPrimaryAction(request, channel, (ReplicationTask) task).run(); + public void messageReceived(ConcreteShardRequest request, TransportChannel channel, Task task) { + new AsyncPrimaryAction(request.request, request.targetAllocationID, channel, (ReplicationTask) task).run(); } } class AsyncPrimaryAction extends AbstractRunnable implements ActionListener { private final Request request; + /** targetAllocationID of the shard this request is meant for */ + private final String targetAllocationID; private final TransportChannel channel; private final ReplicationTask replicationTask; - AsyncPrimaryAction(Request request, TransportChannel channel, ReplicationTask replicationTask) { + AsyncPrimaryAction(Request request, String targetAllocationID, TransportChannel channel, ReplicationTask replicationTask) { this.request = request; + this.targetAllocationID = targetAllocationID; this.channel = channel; this.replicationTask = replicationTask; } @Override protected void doRun() throws Exception { - acquirePrimaryShardReference(request.shardId(), this); + acquirePrimaryShardReference(request.shardId(), targetAllocationID, this); } @Override @@ -271,7 +283,9 @@ public abstract class TransportReplicationAction< final ShardRouting primary = primaryShardReference.routingEntry(); assert primary.relocating() : "indexShard is marked as relocated but routing isn't" + primary; DiscoveryNode relocatingNode = clusterService.state().nodes().get(primary.relocatingNodeId()); - transportService.sendRequest(relocatingNode, transportPrimaryAction, request, transportOptions, + transportService.sendRequest(relocatingNode, transportPrimaryAction, + new ConcreteShardRequest<>(request, primary.allocationId().getRelocationId()), + transportOptions, new TransportChannelResponseHandler(logger, channel, "rerouting indexing to target primary " + primary, TransportReplicationAction.this::newResponseInstance) { @@ -391,15 +405,17 @@ public abstract class TransportReplicationAction< } } - class ReplicaOperationTransportHandler implements TransportRequestHandler { + class ReplicaOperationTransportHandler implements TransportRequestHandler> { @Override - public void messageReceived(final ReplicaRequest request, final TransportChannel channel) throws Exception { + public void messageReceived(final ConcreteShardRequest request, final TransportChannel channel) + throws Exception { throw new UnsupportedOperationException("the task parameter is required for this operation"); } @Override - public void messageReceived(ReplicaRequest request, TransportChannel channel, Task task) throws Exception { - new AsyncReplicaAction(request, channel, (ReplicationTask) task).run(); + public void messageReceived(ConcreteShardRequest requestWithAID, TransportChannel channel, Task task) + throws Exception { + new AsyncReplicaAction(requestWithAID.request, requestWithAID.targetAllocationID, channel, (ReplicationTask) task).run(); } } @@ -417,6 +433,8 @@ public abstract class TransportReplicationAction< private final class AsyncReplicaAction extends AbstractRunnable implements ActionListener { private final ReplicaRequest request; + // allocation id of the replica this request is meant for + private final String targetAllocationID; private final TransportChannel channel; /** * The task on the node with the replica shard. @@ -426,10 +444,11 @@ public abstract class TransportReplicationAction< // something we want to avoid at all costs private final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext()); - AsyncReplicaAction(ReplicaRequest request, TransportChannel channel, ReplicationTask task) { + AsyncReplicaAction(ReplicaRequest request, String targetAllocationID, TransportChannel channel, ReplicationTask task) { this.request = request; this.channel = channel; this.task = task; + this.targetAllocationID = targetAllocationID; } @Override @@ -464,7 +483,9 @@ public abstract class TransportReplicationAction< String extraMessage = "action [" + transportReplicaAction + "], request[" + request + "]"; TransportChannelResponseHandler handler = new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE); - transportService.sendRequest(clusterService.localNode(), transportReplicaAction, request, handler); + transportService.sendRequest(clusterService.localNode(), transportReplicaAction, + new ConcreteShardRequest<>(request, targetAllocationID), + handler); } @Override @@ -501,7 +522,7 @@ public abstract class TransportReplicationAction< protected void doRun() throws Exception { setPhase(task, "replica"); assert request.shardId() != null : "request shardId must be set"; - acquireReplicaOperationLock(request.shardId(), request.primaryTerm(), this); + acquireReplicaOperationLock(request.shardId(), request.primaryTerm(), targetAllocationID, this); } /** @@ -598,7 +619,7 @@ public abstract class TransportReplicationAction< logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}] ", transportPrimaryAction, request.shardId(), request, state.version(), primary.currentNodeId()); } - performAction(node, transportPrimaryAction, true); + performAction(node, transportPrimaryAction, true, new ConcreteShardRequest<>(request, primary.allocationId().getId())); } private void performRemoteAction(ClusterState state, ShardRouting primary, DiscoveryNode node) { @@ -620,7 +641,7 @@ public abstract class TransportReplicationAction< request.shardId(), request, state.version(), primary.currentNodeId()); } setPhase(task, "rerouted"); - performAction(node, actionName, false); + performAction(node, actionName, false, request); } private boolean retryIfUnavailable(ClusterState state, ShardRouting primary) { @@ -671,8 +692,9 @@ public abstract class TransportReplicationAction< } } - private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction) { - transportService.sendRequest(node, action, request, transportOptions, new TransportResponseHandler() { + private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction, + final TransportRequest requestToPerform) { + transportService.sendRequest(node, action, requestToPerform, transportOptions, new TransportResponseHandler() { @Override public Response newInstance() { @@ -700,7 +722,7 @@ public abstract class TransportReplicationAction< (org.apache.logging.log4j.util.Supplier) () -> new ParameterizedMessage( "received an error from node [{}] for request [{}], scheduling a retry", node.getId(), - request), + requestToPerform), exp); retry(exp); } else { @@ -794,7 +816,8 @@ public abstract class TransportReplicationAction< * tries to acquire reference to {@link IndexShard} to perform a primary operation. Released after performing primary operation locally * and replication of the operation to all replica shards is completed / failed (see {@link ReplicationOperation}). */ - protected void acquirePrimaryShardReference(ShardId shardId, ActionListener onReferenceAcquired) { + protected void acquirePrimaryShardReference(ShardId shardId, String allocationId, + ActionListener onReferenceAcquired) { IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); // we may end up here if the cluster state used to route the primary is so stale that the underlying @@ -804,6 +827,10 @@ public abstract class TransportReplicationAction< throw new ReplicationOperation.RetryOnPrimaryException(indexShard.shardId(), "actual shard is not a primary " + indexShard.routingEntry()); } + final String actualAllocationId = indexShard.routingEntry().allocationId().getId(); + if (actualAllocationId.equals(allocationId) == false) { + throw new ShardNotFoundException(shardId, "expected aID [{}] but found [{}]", allocationId, actualAllocationId); + } ActionListener onAcquired = new ActionListener() { @Override @@ -823,9 +850,14 @@ public abstract class TransportReplicationAction< /** * tries to acquire an operation on replicas. The lock is closed as soon as replication is completed on the node. */ - protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, ActionListener onLockAcquired) { + protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, final String allocationId, + ActionListener onLockAcquired) { IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); + final String actualAllocationId = indexShard.routingEntry().allocationId().getId(); + if (actualAllocationId.equals(allocationId) == false) { + throw new ShardNotFoundException(shardId, "expected aID [{}] but found [{}]", allocationId, actualAllocationId); + } indexShard.acquireReplicaOperationLock(primaryTerm, onLockAcquired, executor); } @@ -888,7 +920,8 @@ public abstract class TransportReplicationAction< listener.onFailure(new NoNodeAvailableException("unknown node [" + nodeId + "]")); return; } - transportService.sendRequest(node, transportReplicaAction, request, transportOptions, + transportService.sendRequest(node, transportReplicaAction, + new ConcreteShardRequest<>(request, replica.allocationId().getId()), transportOptions, new ActionListenerResponseHandler<>(listener, () -> TransportResponse.Empty.INSTANCE)); } @@ -930,6 +963,72 @@ public abstract class TransportReplicationAction< } } + /** a wrapper class to encapsulate a request when being sent to a specific allocation id **/ + final class ConcreteShardRequest extends TransportRequest { + + /** {@link AllocationId#getId()} of the shard this request is sent to **/ + private String targetAllocationID; + + private R request; + + ConcreteShardRequest(Supplier requestSupplier) { + request = requestSupplier.get(); + // null now, but will be populated by reading from the streams + targetAllocationID = null; + } + + ConcreteShardRequest(R request, String targetAllocationID) { + Objects.requireNonNull(request); + Objects.requireNonNull(targetAllocationID); + this.request = request; + this.targetAllocationID = targetAllocationID; + } + + @Override + public void setParentTask(String parentTaskNode, long parentTaskId) { + request.setParentTask(parentTaskNode, parentTaskId); + } + + @Override + public void setParentTask(TaskId taskId) { + request.setParentTask(taskId); + } + + @Override + public TaskId getParentTask() { + return request.getParentTask(); + } + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId) { + return request.createTask(id, type, action, parentTaskId); + } + + @Override + public String getDescription() { + return "[" + request.getDescription() + "] for aID [" + targetAllocationID + "]"; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + targetAllocationID = in.readString(); + request.readFrom(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(targetAllocationID); + request.writeTo(out); + } + + public R getRequest() { + return request; + } + + public String getTargetAllocationID() { + return targetAllocationID; + } + } + /** * Sets the current phase on the task if it isn't null. Pulled into its own * method because its more convenient that way. diff --git a/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java b/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java index 2ca165308b1..2d960ce0450 100644 --- a/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java +++ b/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java @@ -27,8 +27,8 @@ import org.elasticsearch.cluster.Diffable; import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.MetaData; -import org.elasticsearch.common.Nullable; import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -88,6 +88,11 @@ public class RoutingTable implements Iterable, Diffable requests = consumeTransportRequests(action); assertThat("no internal requests intercepted for action [" + action + "]", requests.size(), greaterThan(0)); for (TransportRequest internalRequest : requests) { - assertThat(internalRequest, instanceOf(IndicesRequest.class)); - for (String index : ((IndicesRequest) internalRequest).indices()) { + IndicesRequest indicesRequest = convertRequest(internalRequest); + for (String index : indicesRequest.indices()) { assertThat(indices, hasItem(index)); } } } } + static IndicesRequest convertRequest(TransportRequest request) { + final IndicesRequest indicesRequest; + if (request instanceof IndicesRequest) { + indicesRequest = (IndicesRequest) request; + } else { + indicesRequest = TransportReplicationActionTests.resolveRequest(request); + } + return indicesRequest; + } + private String randomIndexOrAlias() { String index = randomFrom(indices); if (randomBoolean()) { diff --git a/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java b/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java index 6c30f015124..c8aec623394 100644 --- a/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.transport.NoNodeAvailableException; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ESAllocationTestCase; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.block.ClusterBlock; import org.elasticsearch.cluster.block.ClusterBlockException; @@ -36,6 +37,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; +import org.elasticsearch.cluster.routing.RoutingNode; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.TestShardRouting; @@ -47,21 +49,25 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.EngineClosedException; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardClosedException; +import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardNotFoundException; +import org.elasticsearch.indices.IndicesService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.cluster.ESAllocationTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseOptions; import org.elasticsearch.transport.TransportService; @@ -75,12 +81,12 @@ import java.io.IOException; import java.util.Arrays; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.stream.Collectors; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state; @@ -93,12 +99,32 @@ import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TransportReplicationActionTests extends ESTestCase { + /** + * takes a request that was sent by a {@link TransportReplicationAction} and captured + * and returns the underlying request if it's wrapped or the original (cast to the expected type). + * + * This will throw a {@link ClassCastException} if the request is of the wrong type. + */ + public static R resolveRequest(TransportRequest requestOrWrappedRequest) { + if (requestOrWrappedRequest instanceof TransportReplicationAction.ConcreteShardRequest) { + requestOrWrappedRequest = ((TransportReplicationAction.ConcreteShardRequest)requestOrWrappedRequest).getRequest(); + } + return (R) requestOrWrappedRequest; + } + private static ThreadPool threadPool; private ClusterService clusterService; @@ -411,7 +437,7 @@ public class TransportReplicationActionTests extends ESTestCase { isRelocated.set(true); executeOnPrimary = false; } - action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) { + action.new AsyncPrimaryAction(request, primaryShard.allocationId().getId(), createTransportChannel(listener), task) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -452,7 +478,8 @@ public class TransportReplicationActionTests extends ESTestCase { final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 0); ClusterState state = state(index, true, ShardRoutingState.RELOCATING); - String primaryTargetNodeId = state.getRoutingTable().shardRoutingTable(shardId).primaryShard().relocatingNodeId(); + final ShardRouting primaryShard = state.getRoutingTable().shardRoutingTable(shardId).primaryShard(); + String primaryTargetNodeId = primaryShard.relocatingNodeId(); // simulate execution of the primary phase on the relocation target node state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(primaryTargetNodeId)).build(); setState(clusterService, state); @@ -460,7 +487,7 @@ public class TransportReplicationActionTests extends ESTestCase { PlainActionFuture listener = new PlainActionFuture<>(); ReplicationTask task = maybeTask(); AtomicBoolean executed = new AtomicBoolean(); - action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) { + action.new AsyncPrimaryAction(request, primaryShard.allocationId().getRelocationId(), createTransportChannel(listener), task) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -473,6 +500,11 @@ public class TransportReplicationActionTests extends ESTestCase { } }; } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } }.run(); assertThat(executed.get(), equalTo(true)); assertPhase(task, "finished"); @@ -596,7 +628,9 @@ public class TransportReplicationActionTests extends ESTestCase { state = ClusterState.builder(state).metaData(metaData).build(); setState(clusterService, state); AtomicBoolean executed = new AtomicBoolean(); - action.new AsyncPrimaryAction(new Request(shardId), createTransportChannel(new PlainActionFuture<>()), null) { + ShardRouting primaryShard = state.routingTable().shardRoutingTable(shardId).primaryShard(); + action.new AsyncPrimaryAction(new Request(shardId), primaryShard.allocationId().getId(), + createTransportChannel(new PlainActionFuture<>()), null) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -613,8 +647,10 @@ public class TransportReplicationActionTests extends ESTestCase { final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 0); // no replica, we only want to test on primary - setState(clusterService, state(index, true, ShardRoutingState.STARTED)); + final ClusterState state = state(index, true, ShardRoutingState.STARTED); + setState(clusterService, state); logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint()); + final ShardRouting primaryShard = state.routingTable().shardRoutingTable(shardId).primaryShard(); Request request = new Request(shardId); PlainActionFuture listener = new PlainActionFuture<>(); ReplicationTask task = maybeTask(); @@ -622,7 +658,7 @@ public class TransportReplicationActionTests extends ESTestCase { final boolean throwExceptionOnCreation = i == 1; final boolean throwExceptionOnRun = i == 2; final boolean respondWithError = i == 3; - action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) { + action.new AsyncPrimaryAction(request, primaryShard.allocationId().getId(), createTransportChannel(listener), task) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -666,8 +702,9 @@ public class TransportReplicationActionTests extends ESTestCase { public void testReplicasCounter() throws Exception { final ShardId shardId = new ShardId("test", "_na_", 0); - setState(clusterService, state(shardId.getIndexName(), true, - ShardRoutingState.STARTED, ShardRoutingState.STARTED)); + final ClusterState state = state(shardId.getIndexName(), true, ShardRoutingState.STARTED, ShardRoutingState.STARTED); + setState(clusterService, state); + final ShardRouting replicaRouting = state.getRoutingTable().shardRoutingTable(shardId).replicaShards().get(0); boolean throwException = randomBoolean(); final ReplicationTask task = maybeTask(); Action action = new Action(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool) { @@ -683,7 +720,8 @@ public class TransportReplicationActionTests extends ESTestCase { }; final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler(); try { - replicaOperationTransportHandler.messageReceived(new Request().setShardId(shardId), + replicaOperationTransportHandler.messageReceived( + action.new ConcreteShardRequest(new Request().setShardId(shardId), replicaRouting.allocationId().getId()), createTransportChannel(new PlainActionFuture<>()), task); } catch (ElasticsearchException e) { assertThat(e.getMessage(), containsString("simulated")); @@ -725,6 +763,111 @@ public class TransportReplicationActionTests extends ESTestCase { assertEquals(ActiveShardCount.from(requestWaitForActiveShards), request.waitForActiveShards()); } + /** test that a primary request is rejected if it arrives at a shard with a wrong allocation id */ + public void testPrimaryActionRejectsWrongAid() throws Exception { + final String index = "test"; + final ShardId shardId = new ShardId(index, "_na_", 0); + setState(clusterService, state(index, true, ShardRoutingState.STARTED)); + PlainActionFuture listener = new PlainActionFuture<>(); + Request request = new Request(shardId).timeout("1ms"); + action.new PrimaryOperationTransportHandler().messageReceived( + action.new ConcreteShardRequest(request, "_not_a_valid_aid_"), + createTransportChannel(listener), maybeTask() + ); + try { + listener.get(); + fail("using a wrong aid didn't fail the operation"); + } catch (ExecutionException execException) { + Throwable throwable = execException.getCause(); + logger.debug("got exception:" , throwable); + assertTrue(throwable.getClass() + " is not a retry exception", action.retryPrimaryException(throwable)); + } + } + + /** test that a replica request is rejected if it arrives at a shard with a wrong allocation id */ + public void testReplicaActionRejectsWrongAid() throws Exception { + final String index = "test"; + final ShardId shardId = new ShardId(index, "_na_", 0); + ClusterState state = state(index, false, ShardRoutingState.STARTED, ShardRoutingState.STARTED); + final ShardRouting replica = state.routingTable().shardRoutingTable(shardId).replicaShards().get(0); + // simulate execution of the node holding the replica + state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build(); + setState(clusterService, state); + + PlainActionFuture listener = new PlainActionFuture<>(); + Request request = new Request(shardId).timeout("1ms"); + action.new ReplicaOperationTransportHandler().messageReceived( + action.new ConcreteShardRequest(request, "_not_a_valid_aid_"), + createTransportChannel(listener), maybeTask() + ); + try { + listener.get(); + fail("using a wrong aid didn't fail the operation"); + } catch (ExecutionException execException) { + Throwable throwable = execException.getCause(); + if (action.retryPrimaryException(throwable) == false) { + throw new AssertionError("thrown exception is not retriable", throwable); + } + assertThat(throwable.getMessage(), containsString("_not_a_valid_aid_")); + } + } + + /** + * test throwing a {@link org.elasticsearch.action.support.replication.TransportReplicationAction.RetryOnReplicaException} + * causes a retry + */ + public void testRetryOnReplica() throws Exception { + final ShardId shardId = new ShardId("test", "_na_", 0); + ClusterState state = state(shardId.getIndexName(), true, ShardRoutingState.STARTED, ShardRoutingState.STARTED); + final ShardRouting replica = state.getRoutingTable().shardRoutingTable(shardId).replicaShards().get(0); + // simulate execution of the node holding the replica + state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build(); + setState(clusterService, state); + AtomicBoolean throwException = new AtomicBoolean(true); + final ReplicationTask task = maybeTask(); + Action action = new Action(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool) { + @Override + protected ReplicaResult shardOperationOnReplica(Request request) { + assertPhase(task, "replica"); + if (throwException.get()) { + throw new RetryOnReplicaException(shardId, "simulation"); + } + return new ReplicaResult(); + } + }; + final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler(); + final PlainActionFuture listener = new PlainActionFuture<>(); + final Request request = new Request().setShardId(shardId); + request.primaryTerm(state.metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id())); + replicaOperationTransportHandler.messageReceived( + action.new ConcreteShardRequest(request, replica.allocationId().getId()), + createTransportChannel(listener), task); + if (listener.isDone()) { + listener.get(); // fail with the exception if there + fail("listener shouldn't be done"); + } + + // no retry yet + List capturedRequests = + transport.getCapturedRequestsByTargetNodeAndClear().get(replica.currentNodeId()); + assertThat(capturedRequests, nullValue()); + + // release the waiting + throwException.set(false); + setState(clusterService, state); + + capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear().get(replica.currentNodeId()); + assertThat(capturedRequests, notNullValue()); + assertThat(capturedRequests.size(), equalTo(1)); + final CapturingTransport.CapturedRequest capturedRequest = capturedRequests.get(0); + assertThat(capturedRequest.action, equalTo("testActionWithExceptions[r]")); + assertThat(capturedRequest.request, instanceOf(TransportReplicationAction.ConcreteShardRequest.class)); + assertThat(((TransportReplicationAction.ConcreteShardRequest) capturedRequest.request).getRequest(), equalTo(request)); + assertThat(((TransportReplicationAction.ConcreteShardRequest) capturedRequest.request).getTargetAllocationID(), + equalTo(replica.allocationId().getId())); + } + + private void assertIndexShardCounter(int expected) { assertThat(count.get(), equalTo(expected)); } @@ -797,7 +940,7 @@ public class TransportReplicationActionTests extends ESTestCase { Action(Settings settings, String actionName, TransportService transportService, ClusterService clusterService, ThreadPool threadPool) { - super(settings, actionName, transportService, clusterService, null, threadPool, + super(settings, actionName, transportService, clusterService, mockIndicesService(clusterService), threadPool, new ShardStateAction(settings, clusterService, transportService, null, null, threadPool), new ActionFilters(new HashSet<>()), new IndexNameExpressionResolver(Settings.EMPTY), Request::new, Request::new, ThreadPool.Names.SAME); @@ -825,43 +968,76 @@ public class TransportReplicationActionTests extends ESTestCase { protected boolean resolveIndex() { return false; } + } - @Override - protected void acquirePrimaryShardReference(ShardId shardId, ActionListener onReferenceAcquired) { + final IndicesService mockIndicesService(ClusterService clusterService) { + final IndicesService indicesService = mock(IndicesService.class); + when(indicesService.indexServiceSafe(any(Index.class))).then(invocation -> { + Index index = (Index)invocation.getArguments()[0]; + final ClusterState state = clusterService.state(); + final IndexMetaData indexSafe = state.metaData().getIndexSafe(index); + return mockIndexService(indexSafe, clusterService); + }); + when(indicesService.indexService(any(Index.class))).then(invocation -> { + Index index = (Index) invocation.getArguments()[0]; + final ClusterState state = clusterService.state(); + if (state.metaData().hasIndex(index.getName())) { + final IndexMetaData indexSafe = state.metaData().getIndexSafe(index); + return mockIndexService(clusterService.state().metaData().getIndexSafe(index), clusterService); + } else { + return null; + } + }); + return indicesService; + } + + final IndexService mockIndexService(final IndexMetaData indexMetaData, ClusterService clusterService) { + final IndexService indexService = mock(IndexService.class); + when(indexService.getShard(anyInt())).then(invocation -> { + int shard = (Integer) invocation.getArguments()[0]; + final ShardId shardId = new ShardId(indexMetaData.getIndex(), shard); + if (shard > indexMetaData.getNumberOfShards()) { + throw new ShardNotFoundException(shardId); + } + return mockIndexShard(shardId, clusterService); + }); + return indexService; + } + + private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService) { + final IndexShard indexShard = mock(IndexShard.class); + doAnswer(invocation -> { + ActionListener callback = (ActionListener) invocation.getArguments()[0]; count.incrementAndGet(); - PrimaryShardReference primaryShardReference = new PrimaryShardReference(null, null) { - @Override - public boolean isRelocated() { - return isRelocated.get(); - } - - @Override - public void failShard(String reason, @Nullable Exception e) { - throw new UnsupportedOperationException(); - } - - @Override - public ShardRouting routingEntry() { - ShardRouting shardRouting = clusterService.state().getRoutingTable() - .shardRoutingTable(shardId).primaryShard(); - assert shardRouting != null; - return shardRouting; - } - - @Override - public void close() { - count.decrementAndGet(); - } - }; - - onReferenceAcquired.onResponse(primaryShardReference); - } - - @Override - protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, ActionListener onLockAcquired) { + callback.onResponse(count::decrementAndGet); + return null; + }).when(indexShard).acquirePrimaryOperationLock(any(ActionListener.class), anyString()); + doAnswer(invocation -> { + long term = (Long)invocation.getArguments()[0]; + ActionListener callback = (ActionListener) invocation.getArguments()[1]; + final long primaryTerm = indexShard.getPrimaryTerm(); + if (term < primaryTerm) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s operation term [%d] is too old (current [%d])", + shardId, term, primaryTerm)); + } count.incrementAndGet(); - onLockAcquired.onResponse(count::decrementAndGet); - } + callback.onResponse(count::decrementAndGet); + return null; + }).when(indexShard).acquireReplicaOperationLock(anyLong(), any(ActionListener.class), anyString()); + when(indexShard.routingEntry()).thenAnswer(invocationOnMock -> { + final ClusterState state = clusterService.state(); + final RoutingNode node = state.getRoutingNodes().node(state.nodes().getLocalNodeId()); + final ShardRouting routing = node.getByShardId(shardId); + if (routing == null) { + throw new ShardNotFoundException(shardId, "shard is no longer assigned to current node"); + } + return routing; + }); + when(indexShard.state()).thenAnswer(invocationOnMock -> isRelocated.get() ? IndexShardState.RELOCATED : IndexShardState.STARTED); + doThrow(new AssertionError("failed shard is not supported")).when(indexShard).failShard(anyString(), any(Exception.class)); + when(indexShard.getPrimaryTerm()).thenAnswer(i -> + clusterService.state().metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id())); + return indexShard; } class NoopReplicationOperation extends ReplicationOperation { @@ -879,11 +1055,6 @@ public class TransportReplicationActionTests extends ESTestCase { * Transport channel that is needed for replica operation testing. */ public TransportChannel createTransportChannel(final PlainActionFuture listener) { - return createTransportChannel(listener, error -> { - }); - } - - public TransportChannel createTransportChannel(final PlainActionFuture listener, Consumer consumer) { return new TransportChannel() { @Override @@ -908,7 +1079,6 @@ public class TransportReplicationActionTests extends ESTestCase { @Override public void sendResponse(Exception exception) throws IOException { - consumer.accept(exception); listener.onFailure(exception); }