Extract TransportRequestDeduplication from ShardStateAction (#37870)

* Extracted the logic for master request duplication so it can be reused by the snapshotting logic
* Removed custom listener used by `ShardStateAction` to not leak these into future users of this class
* Changed semantics slightly to get rid of redundant instantiations of the composite listener
* Relates #37686
This commit is contained in:
Armin Braun 2019-01-30 19:21:09 +01:00 committed by GitHub
parent 6500b0cbd7
commit a070b8acc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 241 additions and 206 deletions

View File

@ -1192,12 +1192,12 @@ public abstract class TransportReplicationAction<
onSuccess.run();
}
protected final ShardStateAction.Listener createShardActionListener(final Runnable onSuccess,
protected final ActionListener<Void> createShardActionListener(final Runnable onSuccess,
final Consumer<Exception> onPrimaryDemoted,
final Consumer<Exception> onIgnoredFailure) {
return new ShardStateAction.Listener() {
return new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
onSuccess.run();
}

View File

@ -25,6 +25,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
@ -48,18 +49,17 @@ import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestDeduplicator;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;
@ -71,7 +71,6 @@ import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Predicate;
import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
@ -89,7 +88,7 @@ public class ShardStateAction {
// a list of shards that failed during replication
// we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
private final ConcurrentMap<FailedShardEntry, CompositeListener> remoteFailedShardsCache = ConcurrentCollections.newConcurrentMap();
private final TransportRequestDeduplicator<FailedShardEntry> remoteFailedShardsDeduplicator = new TransportRequestDeduplicator<>();
@Inject
public ShardStateAction(ClusterService clusterService, TransportService transportService,
@ -106,7 +105,7 @@ public class ShardStateAction {
}
private void sendShardAction(final String actionName, final ClusterState currentState,
final TransportRequest request, final Listener listener) {
final TransportRequest request, final ActionListener<Void> listener) {
ClusterStateObserver observer =
new ClusterStateObserver(currentState, clusterService, null, logger, threadPool.getThreadContext());
DiscoveryNode masterNode = currentState.nodes().getMasterNode();
@ -120,7 +119,7 @@ public class ShardStateAction {
actionName, request, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
listener.onSuccess();
listener.onResponse(null);
}
@Override
@ -163,44 +162,22 @@ public class ShardStateAction {
* @param listener callback upon completion of the request
*/
public void remoteShardFailed(final ShardId shardId, String allocationId, long primaryTerm, boolean markAsStale, final String message,
@Nullable final Exception failure, Listener listener) {
@Nullable final Exception failure, ActionListener<Void> listener) {
assert primaryTerm > 0L : "primary term should be strictly positive";
final FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
final CompositeListener compositeListener = new CompositeListener(listener);
final CompositeListener existingListener = remoteFailedShardsCache.putIfAbsent(shardEntry, compositeListener);
if (existingListener == null) {
sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, new Listener() {
@Override
public void onSuccess() {
try {
compositeListener.onSuccess();
} finally {
remoteFailedShardsCache.remove(shardEntry);
}
}
@Override
public void onFailure(Exception e) {
try {
compositeListener.onFailure(e);
} finally {
remoteFailedShardsCache.remove(shardEntry);
}
}
});
} else {
existingListener.addListener(listener);
}
remoteFailedShardsDeduplicator.executeOnce(
new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale), listener,
(req, reqListener) -> sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), req, reqListener));
}
int remoteShardFailedCacheSize() {
return remoteFailedShardsCache.size();
return remoteFailedShardsDeduplicator.size();
}
/**
* Send a shard failed request to the master node to update the cluster state when a shard on the local node failed.
*/
public void localShardFailed(final ShardRouting shardRouting, final String message,
@Nullable final Exception failure, Listener listener) {
@Nullable final Exception failure, ActionListener<Void> listener) {
localShardFailed(shardRouting, message, failure, listener, clusterService.state());
}
@ -208,7 +185,7 @@ public class ShardStateAction {
* Send a shard failed request to the master node to update the cluster state when a shard on the local node failed.
*/
public void localShardFailed(final ShardRouting shardRouting, final String message, @Nullable final Exception failure,
Listener listener, final ClusterState currentState) {
ActionListener<Void> listener, final ClusterState currentState) {
FailedShardEntry shardEntry = new FailedShardEntry(shardRouting.shardId(), shardRouting.allocationId().getId(),
0L, message, failure, true);
sendShardAction(SHARD_FAILED_ACTION_NAME, currentState, shardEntry, listener);
@ -216,7 +193,8 @@ public class ShardStateAction {
// visible for testing
protected void waitForNewMasterAndRetry(String actionName, ClusterStateObserver observer,
TransportRequest request, Listener listener, Predicate<ClusterState> changePredicate) {
TransportRequest request, ActionListener<Void> listener,
Predicate<ClusterState> changePredicate) {
observer.waitForNextChange(new ClusterStateObserver.Listener() {
@Override
public void onNewClusterState(ClusterState state) {
@ -497,14 +475,14 @@ public class ShardStateAction {
public void shardStarted(final ShardRouting shardRouting,
final long primaryTerm,
final String message,
final Listener listener) {
final ActionListener<Void> listener) {
shardStarted(shardRouting, primaryTerm, message, listener, clusterService.state());
}
public void shardStarted(final ShardRouting shardRouting,
final long primaryTerm,
final String message,
final Listener listener,
final ActionListener<Void> listener,
final ClusterState currentState) {
StartedShardEntry entry = new StartedShardEntry(shardRouting.shardId(), shardRouting.allocationId().getId(), primaryTerm, message);
sendShardAction(SHARD_STARTED_ACTION_NAME, currentState, entry, listener);
@ -670,97 +648,6 @@ public class ShardStateAction {
}
}
public interface Listener {
default void onSuccess() {
}
/**
* Notification for non-channel exceptions that are not handled
* by {@link ShardStateAction}.
*
* The exceptions that are handled by {@link ShardStateAction}
* are:
* - {@link NotMasterException}
* - {@link NodeDisconnectedException}
* - {@link FailedToCommitClusterStateException}
*
* Any other exception is communicated to the requester via
* this notification.
*
* @param e the unexpected cause of the failure on the master
*/
default void onFailure(final Exception e) {
}
}
/**
* A composite listener that allows registering multiple listeners dynamically.
*/
static final class CompositeListener implements Listener {
private boolean isNotified = false;
private Exception failure = null;
private final List<Listener> listeners = new ArrayList<>();
CompositeListener(Listener listener) {
listeners.add(listener);
}
void addListener(Listener listener) {
final boolean ready;
synchronized (this) {
ready = this.isNotified;
if (ready == false) {
listeners.add(listener);
}
}
if (ready) {
if (failure != null) {
listener.onFailure(failure);
} else {
listener.onSuccess();
}
}
}
private void onCompleted(Exception failure) {
synchronized (this) {
this.failure = failure;
this.isNotified = true;
}
RuntimeException firstException = null;
for (Listener listener : listeners) {
try {
if (failure != null) {
listener.onFailure(failure);
} else {
listener.onSuccess();
}
} catch (RuntimeException innerEx) {
if (firstException == null) {
firstException = innerEx;
} else {
firstException.addSuppressed(innerEx);
}
}
}
if (firstException != null) {
throw firstException;
}
}
@Override
public void onSuccess() {
onCompleted(null);
}
@Override
public void onFailure(Exception failure) {
onCompleted(failure);
}
}
public static class NoLongerPrimaryShardException extends ElasticsearchException {
public NoLongerPrimaryShardException(ShardId shardId, String msg) {

View File

@ -109,8 +109,7 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent imple
private final ShardStateAction shardStateAction;
private final NodeMappingRefreshAction nodeMappingRefreshAction;
private static final ShardStateAction.Listener SHARD_STATE_ACTION_LISTENER = new ShardStateAction.Listener() {
};
private static final ActionListener<Void> SHARD_STATE_ACTION_LISTENER = ActionListener.wrap(() -> {});
private final Settings settings;
// a list of shards that failed during recovery

View File

@ -0,0 +1,114 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiConsumer;
/**
* Deduplicator for {@link TransportRequest}s that keeps track of {@link TransportRequest}s that should
* not be sent in parallel.
* @param <T> Transport Request Class
*/
public final class TransportRequestDeduplicator<T extends TransportRequest> {
private final ConcurrentMap<T, CompositeListener> requests = ConcurrentCollections.newConcurrentMap();
/**
* Ensures a given request not executed multiple times when another equal request is already in-flight.
* If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener}
* that must be completed by the caller when the request completes. Once that listener is completed the request will be removed from
* the deduplicator's internal state. If the request is already known to the deduplicator it will keep
* track of the given listener and invoke it when the listener passed to the callback on first invocation is completed.
* @param request Request to deduplicate
* @param listener Listener to invoke on request completion
* @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator
*/
public void executeOnce(T request, ActionListener<Void> listener, BiConsumer<T, ActionListener<Void>> callback) {
ActionListener<Void> completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener);
if (completionListener != null) {
callback.accept(request, completionListener);
}
}
public int size() {
return requests.size();
}
private final class CompositeListener implements ActionListener<Void> {
private final List<ActionListener<Void>> listeners = new ArrayList<>();
private final T request;
private boolean isNotified;
private Exception failure;
CompositeListener(T request) {
this.request = request;
}
CompositeListener addListener(ActionListener<Void> listener) {
synchronized (this) {
if (this.isNotified == false) {
listeners.add(listener);
return listeners.size() == 1 ? this : null;
}
}
if (failure != null) {
listener.onFailure(failure);
} else {
listener.onResponse(null);
}
return null;
}
private void onCompleted(Exception failure) {
synchronized (this) {
this.failure = failure;
this.isNotified = true;
}
try {
if (failure == null) {
ActionListener.onResponse(listeners, null);
} else {
ActionListener.onFailure(listeners, failure);
}
} finally {
requests.remove(request);
}
}
@Override
public void onResponse(final Void aVoid) {
onCompleted(null);
}
@Override
public void onFailure(Exception failure) {
onCompleted(failure);
}
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.cluster.action.shard;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.replication.ClusterStateCreationUtils;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
@ -75,7 +76,6 @@ import static org.elasticsearch.test.ClusterServiceUtils.setState;
import static org.elasticsearch.test.VersionUtils.randomCompatibleVersion;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.sameInstance;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
@ -110,7 +110,7 @@ public class ShardStateActionTests extends ESTestCase {
@Override
protected void waitForNewMasterAndRetry(String actionName, ClusterStateObserver observer, TransportRequest request,
Listener listener, Predicate<ClusterState> changePredicate) {
ActionListener<Void> listener, Predicate<ClusterState> changePredicate) {
onBeforeWaitForNewMasterAndRetry.run();
super.waitForNewMasterAndRetry(actionName, observer, request, listener, changePredicate);
onAfterWaitForNewMasterAndRetry.run();
@ -197,9 +197,9 @@ public class ShardStateActionTests extends ESTestCase {
});
ShardRouting failedShard = getRandomShardRouting(index);
shardStateAction.localShardFailed(failedShard, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
shardStateAction.localShardFailed(failedShard, "test", getSimulatedFailure(), new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
success.set(true);
latch.countDown();
}
@ -246,9 +246,9 @@ public class ShardStateActionTests extends ESTestCase {
setUpMasterRetryVerification(numberOfRetries, retries, latch, retryLoop);
ShardRouting failedShard = getRandomShardRouting(index);
shardStateAction.localShardFailed(failedShard, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
shardStateAction.localShardFailed(failedShard, "test", getSimulatedFailure(), new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
success.set(true);
latch.countDown();
}
@ -343,9 +343,9 @@ public class ShardStateActionTests extends ESTestCase {
long primaryTerm = randomLongBetween(1, Long.MAX_VALUE);
for (int i = 0; i < numListeners; i++) {
shardStateAction.remoteShardFailed(failedShard.shardId(), failedShard.allocationId().getId(),
primaryTerm, markAsStale, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
primaryTerm, markAsStale, "test", getSimulatedFailure(), new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
latch.countDown();
}
@Override
@ -394,9 +394,9 @@ public class ShardStateActionTests extends ESTestCase {
ShardRouting failedShard = randomFrom(failedShards);
shardStateAction.remoteShardFailed(failedShard.shardId(), failedShard.allocationId().getId(),
randomLongBetween(1, Long.MAX_VALUE), randomBoolean(), "test", getSimulatedFailure(),
new ShardStateAction.Listener() {
new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
notifiedResponses.incrementAndGet();
}
@Override
@ -561,70 +561,13 @@ public class ShardStateActionTests extends ESTestCase {
}
}
public void testCompositeListener() throws Exception {
AtomicInteger successCount = new AtomicInteger();
AtomicInteger failureCount = new AtomicInteger();
Exception failure = randomBoolean() ? getSimulatedFailure() : null;
ShardStateAction.CompositeListener compositeListener = new ShardStateAction.CompositeListener(new ShardStateAction.Listener() {
@Override
public void onSuccess() {
successCount.incrementAndGet();
}
@Override
public void onFailure(Exception e) {
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
});
int iterationsPerThread = scaledRandomIntBetween(100, 1000);
Thread[] threads = new Thread[between(1, 4)];
Phaser barrier = new Phaser(threads.length + 1);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
barrier.arriveAndAwaitAdvance();
for (int n = 0; n < iterationsPerThread; n++) {
compositeListener.addListener(new ShardStateAction.Listener() {
@Override
public void onSuccess() {
successCount.incrementAndGet();
}
@Override
public void onFailure(Exception e) {
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
});
}
});
threads[i].start();
}
barrier.arriveAndAwaitAdvance();
if (failure != null) {
compositeListener.onFailure(failure);
} else {
compositeListener.onSuccess();
}
for (Thread t : threads) {
t.join();
}
assertBusy(() -> {
if (failure != null) {
assertThat(successCount.get(), equalTo(0));
assertThat(failureCount.get(), equalTo(threads.length*iterationsPerThread + 1));
} else {
assertThat(successCount.get(), equalTo(threads.length*iterationsPerThread + 1));
assertThat(failureCount.get(), equalTo(0));
}
});
}
private static class TestListener implements ShardStateAction.Listener {
private static class TestListener implements ActionListener<Void> {
private final SetOnce<Exception> failure = new SetOnce<>();
private final CountDownLatch latch = new CountDownLatch(1);
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
try {
failure.set(null);
} finally {

View File

@ -22,6 +22,7 @@ package org.elasticsearch.discovery;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.index.CorruptIndexException;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexResponse;
@ -317,10 +318,10 @@ public class ClusterDisruptionIT extends AbstractDisruptionTestCase {
setDisruptionScheme(networkDisruption);
networkDisruption.startDisrupting();
service.localShardFailed(failedShard, "simulated", new CorruptIndexException("simulated", (String) null), new
ShardStateAction.Listener() {
service.localShardFailed(failedShard, "simulated", new CorruptIndexException("simulated", (String) null),
new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(final Void aVoid) {
success.set(true);
latch.countDown();
}

View File

@ -0,0 +1,91 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.sameInstance;
public class TransportRequestDeduplicatorTests extends ESTestCase {
public void testRequestDeduplication() throws Exception {
AtomicInteger successCount = new AtomicInteger();
AtomicInteger failureCount = new AtomicInteger();
Exception failure = randomBoolean() ? new TransportException("simulated") : null;
final TransportRequest request = new TransportRequest() {
@Override
public void setParentTask(final TaskId taskId) {
}
};
final TransportRequestDeduplicator<TransportRequest> deduplicator = new TransportRequestDeduplicator<>();
final SetOnce<ActionListener<Void>> listenerHolder = new SetOnce<>();
int iterationsPerThread = scaledRandomIntBetween(100, 1000);
Thread[] threads = new Thread[between(1, 4)];
Phaser barrier = new Phaser(threads.length + 1);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
barrier.arriveAndAwaitAdvance();
for (int n = 0; n < iterationsPerThread; n++) {
deduplicator.executeOnce(request, new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
successCount.incrementAndGet();
}
@Override
public void onFailure(Exception e) {
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
}, (req, reqListener) -> listenerHolder.set(reqListener));
}
});
threads[i].start();
}
barrier.arriveAndAwaitAdvance();
for (Thread t : threads) {
t.join();
}
final ActionListener<Void> listener = listenerHolder.get();
assertThat(deduplicator.size(), equalTo(1));
if (failure != null) {
listener.onFailure(failure);
} else {
listener.onResponse(null);
}
assertThat(deduplicator.size(), equalTo(0));
assertBusy(() -> {
if (failure != null) {
assertThat(successCount.get(), equalTo(0));
assertThat(failureCount.get(), equalTo(threads.length * iterationsPerThread));
} else {
assertThat(successCount.get(), equalTo(threads.length * iterationsPerThread));
assertThat(failureCount.get(), equalTo(0));
}
});
}
}