Replace failure.get().addSuppressed with failure.accumulateAndGet() (#37649)

Also add a test for concurrent incoming failures
This commit is contained in:
Luca Cavanna 2019-01-29 14:57:33 +01:00 committed by GitHub
parent a6d4838a67
commit 42eec55837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 18 deletions

View File

@ -72,7 +72,10 @@ public final class GroupedActionListener<T> implements ActionListener<T> {
@Override
public void onFailure(Exception e) {
if (failure.compareAndSet(null, e) == false) {
failure.get().addSuppressed(e);
failure.accumulateAndGet(e, (previous, current) -> {
previous.addSuppressed(current);
return previous;
});
}
if (countDown.countDown()) {
delegate.onFailure(failure.get());

View File

@ -26,10 +26,14 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.instanceOf;
public class GroupedActionListenerTests extends ESTestCase {
public void testNotifications() throws InterruptedException {
@ -55,20 +59,17 @@ public class GroupedActionListenerTests extends ESTestCase {
Thread[] threads = new Thread[numThreads];
CyclicBarrier barrier = new CyclicBarrier(numThreads);
for (int i = 0; i < numThreads; i++) {
threads[i] = new Thread() {
@Override
public void run() {
try {
barrier.await(10, TimeUnit.SECONDS);
} catch (Exception e) {
throw new AssertionError(e);
}
int c = 0;
while((c = count.incrementAndGet()) <= groupSize) {
listener.onResponse(c-1);
}
threads[i] = new Thread(() -> {
try {
barrier.await(10, TimeUnit.SECONDS);
} catch (Exception e) {
throw new AssertionError(e);
}
};
int c = 0;
while((c = count.incrementAndGet()) <= groupSize) {
listener.onResponse(c-1);
}
});
threads[i].start();
}
for (Thread t : threads) {
@ -100,11 +101,9 @@ public class GroupedActionListenerTests extends ESTestCase {
excRef.set(e);
}
};
Collection<Integer> defaults = randomBoolean() ? Collections.singletonList(-1) :
Collections.emptyList();
Collection<Integer> defaults = randomBoolean() ? Collections.singletonList(-1) : Collections.emptyList();
int size = randomIntBetween(3, 4);
GroupedActionListener<Integer> listener = new GroupedActionListener<>(result, size,
defaults);
GroupedActionListener<Integer> listener = new GroupedActionListener<>(result, size, defaults);
listener.onResponse(0);
IOException ioException = new IOException();
RuntimeException rtException = new RuntimeException();
@ -121,4 +120,23 @@ public class GroupedActionListenerTests extends ESTestCase {
listener.onResponse(1);
assertNull(resRef.get());
}
public void testConcurrentFailures() throws InterruptedException {
AtomicReference<Exception> finalException = new AtomicReference<>();
int numGroups = randomIntBetween(10, 100);
GroupedActionListener<Void> listener = new GroupedActionListener<>(
ActionListener.wrap(r -> {}, finalException::set), numGroups, Collections.emptyList());
ExecutorService executorService = Executors.newFixedThreadPool(numGroups);
for (int i = 0; i < numGroups; i++) {
executorService.submit(() -> listener.onFailure(new IOException()));
}
executorService.shutdown();
executorService.awaitTermination(10, TimeUnit.SECONDS);
Exception exception = finalException.get();
assertNotNull(exception);
assertThat(exception, instanceOf(IOException.class));
assertEquals(numGroups - 1, exception.getSuppressed().length);
}
}