diff --git a/src/main/java/org/elasticsearch/common/util/CancellableThreads.java b/src/main/java/org/elasticsearch/common/util/CancellableThreads.java new file mode 100644 index 00000000000..23313de5172 --- /dev/null +++ b/src/main/java/org/elasticsearch/common/util/CancellableThreads.java @@ -0,0 +1,145 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.common.util; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Nullable; + +import java.util.HashSet; +import java.util.Set; + +/** + * A utility class for multi threaded operation that needs to be cancellable via interrupts. Every cancellable operation should be + * executed via {@link #execute(Interruptable)}, which will capture the executing thread and make sure it is interrupted in the case + * cancellation. + */ +public class CancellableThreads { + private final Set threads = new HashSet<>(); + private boolean cancelled = false; + private String reason; + + public synchronized boolean isCancelled() { + return cancelled; + } + + + /** call this will throw an exception if operation was cancelled. Override {@link #onCancel(String, java.lang.Throwable)} for custom failure logic */ + public synchronized void checkForCancel() { + if (isCancelled()) { + onCancel(reason, null); + } + } + + /** + * called if {@link #checkForCancel()} was invoked after the operation was cancelled. + * the default implementation always throws an {@link ExecutionCancelledException}, suppressing + * any other exception that occurred before cancellation + * + * @param reason reason for failure supplied by the caller of {@link @cancel} + * @param suppressedException any error that was encountered during the execution before the operation was cancelled. + */ + protected void onCancel(String reason, @Nullable Throwable suppressedException) { + RuntimeException e = new ExecutionCancelledException("operation was cancelled reason [" + reason + "]"); + if (suppressedException != null) { + e.addSuppressed(suppressedException); + } + throw e; + } + + private synchronized boolean add() { + checkForCancel(); + threads.add(Thread.currentThread()); + // capture and clean the interrupted thread before we start, so we can identify + // our own interrupt. we do so under lock so we know we don't clear our own. + return Thread.interrupted(); + } + + /** + * run the Interruptable, capturing the executing thread. Concurrent calls to {@link #cancel(String)} will interrupt this thread + * causing the call to prematurely return. + * + * @param interruptable code to run + */ + public void execute(Interruptable interruptable) { + boolean wasInterrupted = add(); + RuntimeException throwable = null; + try { + interruptable.run(); + } catch (InterruptedException e) { + // assume this is us and ignore + } catch (RuntimeException t) { + throwable = t; + } finally { + remove(); + } + // we are now out of threads collection so we can't be interrupted any more by this class + // restore old flag and see if we need to fail + if (wasInterrupted) { + Thread.currentThread().interrupt(); + } else { + // clear the flag interrupted flag as we are checking for failure.. + Thread.interrupted(); + } + synchronized (this) { + if (isCancelled()) { + onCancel(reason, throwable); + } else if (throwable != null) { + // if we're not canceling, we throw the original exception + throw throwable; + } + } + } + + + private synchronized void remove() { + threads.remove(Thread.currentThread()); + } + + /** cancel all current running operations. Future calls to {@link #checkForCancel()} will be failed with the given reason */ + public synchronized void cancel(String reason) { + if (cancelled) { + // we were already cancelled, make sure we don't interrupt threads twice + // this is important in order to make sure that we don't mark + // Thread.interrupted without handling it + return; + } + cancelled = true; + this.reason = reason; + for (Thread thread : threads) { + thread.interrupt(); + } + threads.clear(); + } + + + public interface Interruptable { + public void run() throws InterruptedException; + } + + public class ExecutionCancelledException extends ElasticsearchException { + + public ExecutionCancelledException(String msg) { + super(msg); + } + + public ExecutionCancelledException(String msg, Throwable cause) { + super(msg, cause); + } + } +} diff --git a/src/main/java/org/elasticsearch/indices/recovery/RecoveriesCollection.java b/src/main/java/org/elasticsearch/indices/recovery/RecoveriesCollection.java index ed8667eb11b..16ffe8fe98b 100644 --- a/src/main/java/org/elasticsearch/indices/recovery/RecoveriesCollection.java +++ b/src/main/java/org/elasticsearch/indices/recovery/RecoveriesCollection.java @@ -23,9 +23,9 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.logging.ESLogger; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardClosedException; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.index.shard.IndexShard; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/src/main/java/org/elasticsearch/indices/recovery/RecoveryStatus.java b/src/main/java/org/elasticsearch/indices/recovery/RecoveryStatus.java index 0ae8ea26a99..b4d93b5716c 100644 --- a/src/main/java/org/elasticsearch/indices/recovery/RecoveryStatus.java +++ b/src/main/java/org/elasticsearch/indices/recovery/RecoveryStatus.java @@ -21,15 +21,15 @@ package org.elasticsearch.indices.recovery; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.IOUtils; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.logging.ESLogger; import org.elasticsearch.common.logging.Loggers; +import org.elasticsearch.common.util.CancellableThreads; import org.elasticsearch.common.util.concurrent.AbstractRefCounted; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; -import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.store.Store; import org.elasticsearch.index.store.StoreFileMetaData; @@ -40,7 +40,6 @@ import java.util.Map.Entry; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; /** * @@ -64,13 +63,13 @@ public class RecoveryStatus extends AbstractRefCounted { private final Store store; private final RecoveryTarget.RecoveryListener listener; - private AtomicReference waitingRecoveryThread = new AtomicReference<>(); - private final AtomicBoolean finished = new AtomicBoolean(); private final ConcurrentMap openIndexOutputs = ConcurrentCollections.newConcurrentMap(); private final Store.LegacyChecksums legacyChecksums = new Store.LegacyChecksums(); + private final CancellableThreads cancellableThreads = new CancellableThreads(); + public RecoveryStatus(IndexShard indexShard, DiscoveryNode sourceNode, RecoveryState state, RecoveryTarget.RecoveryListener listener) { super("recovery_status"); this.recoveryId = idGenerator.incrementAndGet(); @@ -110,25 +109,15 @@ public class RecoveryStatus extends AbstractRefCounted { return state; } + public CancellableThreads CancellableThreads() { + return cancellableThreads; + } + public Store store() { ensureRefCount(); return store; } - /** set a thread that should be interrupted if the recovery is canceled */ - public void setWaitingRecoveryThread(Thread thread) { - ensureRefCount(); - waitingRecoveryThread.set(thread); - } - - /** - * clear the thread set by {@link #setWaitingRecoveryThread(Thread)}, making sure we - * do not override another thread. - */ - public void clearWaitingRecoveryThread(Thread threadToClear) { - waitingRecoveryThread.compareAndSet(threadToClear, null); - } - public void stage(RecoveryState.Stage stage) { state.setStage(stage); } @@ -147,22 +136,21 @@ public class RecoveryStatus extends AbstractRefCounted { store.renameFilesSafe(tempFileNames); } - /** cancel the recovery. calling this method will clean temporary files and release the store + /** + * cancel the recovery. calling this method will clean temporary files and release the store * unless this object is in use (in which case it will be cleaned once all ongoing users call * {@link #decRef()} - * - * if {@link #setWaitingRecoveryThread(Thread)} was used, the thread will be interrupted. + *

+ * if {@link #CancellableThreads()} was used, the threads will be interrupted. */ public void cancel(String reason) { if (finished.compareAndSet(false, true)) { - logger.debug("recovery canceled (reason: [{}])", reason); - // release the initial reference. recovery files will be cleaned as soon as ref count goes to zero, potentially now - decRef(); - - final Thread thread = waitingRecoveryThread.get(); - if (thread != null) { - logger.debug("interrupting recovery thread on canceled recovery"); - thread.interrupt(); + try { + logger.debug("recovery canceled (reason: [{}])", reason); + cancellableThreads.cancel(reason); + } finally { + // release the initial reference. recovery files will be cleaned as soon as ref count goes to zero, potentially now + decRef(); } } } @@ -170,16 +158,20 @@ public class RecoveryStatus extends AbstractRefCounted { /** * fail the recovery and call listener * - * @param e exception that encapsulating the failure + * @param e exception that encapsulating the failure * @param sendShardFailure indicates whether to notify the master of the shard failure - **/ + */ public void fail(RecoveryFailedException e, boolean sendShardFailure) { if (finished.compareAndSet(false, true)) { try { listener.onRecoveryFailure(state, e, sendShardFailure); } finally { - // release the initial reference. recovery files will be cleaned as soon as ref count goes to zero, potentially now - decRef(); + try { + cancellableThreads.cancel("failed recovery [" + e.getMessage() + "]"); + } finally { + // release the initial reference. recovery files will be cleaned as soon as ref count goes to zero, potentially now + decRef(); + } } } } @@ -246,7 +238,12 @@ public class RecoveryStatus extends AbstractRefCounted { Iterator> iterator = openIndexOutputs.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry entry = iterator.next(); - IOUtils.closeWhileHandlingException(entry.getValue()); + logger.trace("closing IndexOutput file [{}]", entry.getValue()); + try { + entry.getValue().close(); + } catch (Throwable t) { + logger.debug("error while closing recovery output [{}]", t, entry.getValue()); + } iterator.remove(); } // trash temporary files diff --git a/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java b/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java index 5bfaf8346bb..4053d6c29c9 100644 --- a/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java +++ b/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java @@ -33,14 +33,11 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.CancellableThreads; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.index.IndexShardMissingException; import org.elasticsearch.index.engine.RecoveryEngineException; -import org.elasticsearch.index.shard.IllegalIndexShardStateException; -import org.elasticsearch.index.shard.IndexShardClosedException; -import org.elasticsearch.index.shard.IndexShardNotStartedException; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.index.shard.*; import org.elasticsearch.index.store.Store; import org.elasticsearch.index.store.StoreFileMetaData; import org.elasticsearch.index.translog.Translog; @@ -51,6 +48,7 @@ import org.elasticsearch.transport.*; import java.util.Collections; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.common.unit.TimeValue.timeValueMillis; @@ -160,22 +158,27 @@ public class RecoveryTarget extends AbstractComponent { new RecoveryFailedException(recoveryStatus.state(), "failed to list local files", e), true); return; } - StartRecoveryRequest request = new StartRecoveryRequest(recoveryStatus.shardId(), recoveryStatus.sourceNode(), clusterService.localNode(), + final StartRecoveryRequest request = new StartRecoveryRequest(recoveryStatus.shardId(), recoveryStatus.sourceNode(), clusterService.localNode(), false, existingFiles, recoveryStatus.state().getType(), recoveryStatus.recoveryId()); + final AtomicReference responseHolder = new AtomicReference<>(); try { logger.trace("[{}][{}] starting recovery from {}", request.shardId().index().name(), request.shardId().id(), request.sourceNode()); StopWatch stopWatch = new StopWatch().start(); - recoveryStatus.setWaitingRecoveryThread(Thread.currentThread()); - - RecoveryResponse recoveryResponse = transportService.submitRequest(request.sourceNode(), RecoverySource.Actions.START_RECOVERY, request, new FutureTransportResponseHandler() { + recoveryStatus.CancellableThreads().execute(new CancellableThreads.Interruptable() { @Override - public RecoveryResponse newInstance() { - return new RecoveryResponse(); + public void run() throws InterruptedException { + responseHolder.set(transportService.submitRequest(request.sourceNode(), RecoverySource.Actions.START_RECOVERY, request, new FutureTransportResponseHandler() { + @Override + public RecoveryResponse newInstance() { + return new RecoveryResponse(); + } + }).txGet()); } - }).txGet(); - recoveryStatus.clearWaitingRecoveryThread(Thread.currentThread()); + }); + final RecoveryResponse recoveryResponse = responseHolder.get(); + assert responseHolder != null; stopWatch.stop(); if (logger.isTraceEnabled()) { StringBuilder sb = new StringBuilder(); @@ -197,6 +200,8 @@ public class RecoveryTarget extends AbstractComponent { } // do this through ongoing recoveries to remove it from the collection onGoingRecoveries.markRecoveryAsDone(recoveryStatus.recoveryId()); + } catch (CancellableThreads.ExecutionCancelledException e) { + logger.trace("recovery cancelled", e); } catch (Throwable e) { if (logger.isTraceEnabled()) { logger.trace("[{}][{}] Got exception on recovery", e, request.shardId().index().name(), request.shardId().id()); @@ -478,8 +483,6 @@ public class RecoveryTarget extends AbstractComponent { try { doRecovery(statusRef.status()); } finally { - // make sure we never interrupt the thread after we have released it back to the pool - statusRef.status().clearWaitingRecoveryThread(Thread.currentThread()); statusRef.close(); } } diff --git a/src/main/java/org/elasticsearch/indices/recovery/ShardRecoveryHandler.java b/src/main/java/org/elasticsearch/indices/recovery/ShardRecoveryHandler.java index ad32499a8ba..a45b3a01ee4 100644 --- a/src/main/java/org/elasticsearch/indices/recovery/ShardRecoveryHandler.java +++ b/src/main/java/org/elasticsearch/indices/recovery/ShardRecoveryHandler.java @@ -40,16 +40,18 @@ import org.elasticsearch.common.compress.CompressorFactory; import org.elasticsearch.common.logging.ESLogger; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.CancellableThreads; +import org.elasticsearch.common.util.CancellableThreads.Interruptable; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.index.IndexService; import org.elasticsearch.index.deletionpolicy.SnapshotIndexCommit; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.mapper.DocumentMapper; -import org.elasticsearch.index.IndexService; import org.elasticsearch.index.shard.IllegalIndexShardStateException; +import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardClosedException; import org.elasticsearch.index.shard.IndexShardState; -import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.store.Store; import org.elasticsearch.index.store.StoreFileMetaData; import org.elasticsearch.index.translog.Translog; @@ -59,9 +61,7 @@ import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; -import java.util.HashSet; import java.util.List; -import java.util.Set; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicReference; @@ -88,14 +88,19 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { private final MappingUpdatedAction mappingUpdatedAction; private final RecoveryResponse response; - private final CancelableThreads cancelableThreads = new CancelableThreads() { + private final CancellableThreads cancellableThreads = new CancellableThreads() { @Override - protected void fail(String reason) { + protected void onCancel(String reason, @Nullable Throwable suppressedException) { + RuntimeException e; if (shard.state() == IndexShardState.CLOSED) { // check if the shard got closed on us - throw new IndexShardClosedException(shard.shardId(), "shard is closed and recovery was canceled reason [" + reason + "]"); + e = new IndexShardClosedException(shard.shardId(), "shard is closed and recovery was canceled reason [" + reason + "]"); } else { - throw new ElasticsearchException("recovery was canceled reason [" + reason + "]"); + e = new ExecutionCancelledException("recovery was canceled reason [" + reason + "]"); } + if (suppressedException != null) { + e.addSuppressed(suppressedException); + } + throw e; } }; @@ -141,7 +146,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { */ @Override public void phase1(final SnapshotIndexCommit snapshot) throws ElasticsearchException { - cancelableThreads.failIfCanceled(); + cancellableThreads.checkForCancel(); // Total size of segment files that are recovered long totalSize = 0; // Total size of segment files that were able to be re-used @@ -191,7 +196,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { logger.trace("[{}][{}] recovery [phase1] to {}: recovering_files [{}] with total_size [{}], reusing_files [{}] with total_size [{}]", indexName, shardId, request.targetNode(), response.phase1FileNames.size(), new ByteSizeValue(totalSize), response.phase1ExistingFileNames.size(), new ByteSizeValue(existingTotalSize)); - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { RecoveryFilesInfoRequest recoveryInfoFilesRequest = new RecoveryFilesInfoRequest(request.recoveryId(), request.shardId(), @@ -245,7 +250,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { @Override protected void doRun() { - cancelableThreads.failIfCanceled(); + cancellableThreads.checkForCancel(); store.incRef(); final StoreFileMetaData md = recoverySourceMetadata.get(name); try (final IndexInput indexInput = store.directory().openInput(name, IOContext.READONCE)) { @@ -279,7 +284,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { final BytesArray content = new BytesArray(buf, 0, toRead); readCount += toRead; final boolean lastChunk = readCount == len; - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { // Actually send the file chunk to the target node, waiting for it to complete @@ -319,7 +324,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { fileIndex++; } - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { // Wait for all files that need to be transferred to finish transferring @@ -333,7 +338,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { ExceptionsHelper.rethrowAndSuppress(exceptions); } - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { // Send the CLEAN_FILES request, which takes all of the files that @@ -377,10 +382,10 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { if (shard.state() == IndexShardState.CLOSED) { throw new IndexShardClosedException(request.shardId()); } - cancelableThreads.failIfCanceled(); + cancellableThreads.checkForCancel(); logger.trace("{} recovery [phase2] to {}: start", request.shardId(), request.targetNode()); StopWatch stopWatch = new StopWatch().start(); - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { // Send a request preparing the new shard's translog to receive @@ -428,14 +433,14 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { if (shard.state() == IndexShardState.CLOSED) { throw new IndexShardClosedException(request.shardId()); } - cancelableThreads.failIfCanceled(); + cancellableThreads.checkForCancel(); logger.trace("[{}][{}] recovery [phase3] to {}: sending transaction log operations", indexName, shardId, request.targetNode()); StopWatch stopWatch = new StopWatch().start(); // Send the translog operations to the target node int totalOperations = sendSnapshot(snapshot); - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { // Send the FINALIZE request to the target node. The finalize request @@ -512,7 +517,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { latch.countDown(); } }); - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { latch.await(); @@ -537,7 +542,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { for (DocumentMapper documentMapper : documentMappersToUpdate) { mappingUpdatedAction.updateMappingOnMaster(indexService.index().getName(), documentMapper, indexService.indexUUID(), listener); } - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { try { @@ -577,7 +582,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { if (shard.state() == IndexShardState.CLOSED) { throw new IndexShardClosedException(request.shardId()); } - cancelableThreads.failIfCanceled(); + cancellableThreads.checkForCancel(); operations.add(operation); ops += 1; size += operation.estimateSize(); @@ -596,7 +601,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { // recoverySettings.rateLimiter().pause(size); // } - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { final RecoveryTranslogOperationsRequest translogOperationsRequest = new RecoveryTranslogOperationsRequest(request.recoveryId(), request.shardId(), operations); @@ -613,7 +618,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { } // send the leftover if (!operations.isEmpty()) { - cancelableThreads.run(new Interruptable() { + cancellableThreads.execute(new Interruptable() { @Override public void run() throws InterruptedException { RecoveryTranslogOperationsRequest translogOperationsRequest = new RecoveryTranslogOperationsRequest(request.recoveryId(), request.shardId(), operations); @@ -630,62 +635,7 @@ public final class ShardRecoveryHandler implements Engine.RecoveryHandler { * Cancels the recovery and interrupts all eligible threads. */ public void cancel(String reason) { - cancelableThreads.cancel(reason); - } - - private static abstract class CancelableThreads { - private final Set threads = new HashSet<>(); - private boolean canceled = false; - private String reason; - - public synchronized boolean isCanceled() { - return canceled; - } - - - public synchronized void failIfCanceled() { - if (isCanceled()) { - fail(reason); - } - } - - protected abstract void fail(String reason); - - private synchronized void add() { - failIfCanceled(); - threads.add(Thread.currentThread()); - } - - public void run(Interruptable interruptable) { - add(); - try { - interruptable.run(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } finally { - remove(); - } - } - - private synchronized void remove() { - threads.remove(Thread.currentThread()); - failIfCanceled(); - } - - public synchronized void cancel(String reason) { - canceled = true; - this.reason = reason; - for (Thread thread : threads) { - thread.interrupt(); - } - threads.clear(); - } - - - } - - interface Interruptable { - public void run() throws InterruptedException; + cancellableThreads.cancel(reason); } @Override diff --git a/src/test/java/org/elasticsearch/common/util/CancellableThreadsTest.java b/src/test/java/org/elasticsearch/common/util/CancellableThreadsTest.java new file mode 100644 index 00000000000..66db5241b7e --- /dev/null +++ b/src/test/java/org/elasticsearch/common/util/CancellableThreadsTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.common.util; + +import org.elasticsearch.common.util.CancellableThreads.Interruptable; +import org.elasticsearch.test.ElasticsearchTestCase; +import org.hamcrest.Matchers; +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; + +public class CancellableThreadsTest extends ElasticsearchTestCase { + + private static class CustomException extends RuntimeException { + + public CustomException(String msg) { + super(msg); + } + } + + private class TestPlan { + public final int id; + public final boolean busySpin; + public final boolean exceptBeforeCancel; + public final boolean exitBeforeCancel; + public final boolean exceptAfterCancel; + public final boolean presetInterrupt; + + private TestPlan(int id) { + this.id = id; + this.busySpin = randomBoolean(); + this.exceptBeforeCancel = randomBoolean(); + this.exitBeforeCancel = randomBoolean(); + this.exceptAfterCancel = randomBoolean(); + this.presetInterrupt = randomBoolean(); + } + } + + + @Test + public void testCancellableThreads() throws InterruptedException { + Thread[] threads = new Thread[randomIntBetween(3, 10)]; + final TestPlan[] plans = new TestPlan[threads.length]; + final Throwable[] throwables = new Throwable[threads.length]; + final boolean[] interrupted = new boolean[threads.length]; + final CancellableThreads cancellableThreads = new CancellableThreads(); + final CountDownLatch readyForCancel = new CountDownLatch(threads.length); + for (int i = 0; i < threads.length; i++) { + final TestPlan plan = new TestPlan(i); + plans[i] = plan; + threads[i] = new Thread(new Runnable() { + @Override + public void run() { + try { + if (plan.presetInterrupt) { + Thread.currentThread().interrupt(); + } + cancellableThreads.execute(new Interruptable() { + @Override + public void run() throws InterruptedException { + assertFalse("interrupt thread should have been clear", Thread.currentThread().isInterrupted()); + if (plan.exceptBeforeCancel) { + throw new CustomException("thread [" + plan.id + "] pre-cancel exception"); + } else if (plan.exitBeforeCancel) { + return; + } + readyForCancel.countDown(); + try { + if (plan.busySpin) { + while (!Thread.currentThread().isInterrupted()) { + } + } else { + Thread.sleep(50000); + } + } finally { + if (plan.exceptAfterCancel) { + throw new CustomException("thread [" + plan.id + "] post-cancel exception"); + } + } + } + }); + } catch (Throwable t) { + throwables[plan.id] = t; + } + if (plan.exceptBeforeCancel || plan.exitBeforeCancel) { + // we have to mark we're ready now (actually done). + readyForCancel.countDown(); + } + interrupted[plan.id] = Thread.currentThread().isInterrupted(); + + } + }); + threads[i].setDaemon(true); + threads[i].start(); + } + + readyForCancel.await(); + cancellableThreads.cancel("test"); + for (Thread thread : threads) { + thread.join(20000); + assertFalse(thread.isAlive()); + } + for (int i = 0; i < threads.length; i++) { + TestPlan plan = plans[i]; + if (plan.exceptBeforeCancel) { + assertThat(throwables[i], Matchers.instanceOf(CustomException.class)); + } else if (plan.exitBeforeCancel) { + assertNull(throwables[i]); + } else { + // in all other cases, we expect a cancellation exception. + assertThat(throwables[i], Matchers.instanceOf(CancellableThreads.ExecutionCancelledException.class)); + if (plan.exceptAfterCancel) { + assertThat(throwables[i].getSuppressed(), + Matchers.arrayContaining( + Matchers.instanceOf(CustomException.class) + )); + } else { + assertThat(throwables[i].getSuppressed(), Matchers.emptyArray()); + } + } + assertThat(interrupted[plan.id], Matchers.equalTo(plan.presetInterrupt)); + } + } + +}