diff --git a/server/src/main/java/org/elasticsearch/action/support/GroupedActionListener.java b/server/src/main/java/org/elasticsearch/action/support/GroupedActionListener.java index ed9b7c8d15d..532396ee609 100644 --- a/server/src/main/java/org/elasticsearch/action/support/GroupedActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/support/GroupedActionListener.java @@ -72,7 +72,10 @@ public final class GroupedActionListener implements ActionListener { @Override public void onFailure(Exception e) { if (failure.compareAndSet(null, e) == false) { - failure.get().addSuppressed(e); + failure.accumulateAndGet(e, (previous, current) -> { + previous.addSuppressed(current); + return previous; + }); } if (countDown.countDown()) { delegate.onFailure(failure.get()); diff --git a/server/src/test/java/org/elasticsearch/action/support/GroupedActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/support/GroupedActionListenerTests.java index 2af2da7ba09..9f6454d4e4b 100644 --- a/server/src/test/java/org/elasticsearch/action/support/GroupedActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/GroupedActionListenerTests.java @@ -26,10 +26,14 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.CoreMatchers.instanceOf; + public class GroupedActionListenerTests extends ESTestCase { public void testNotifications() throws InterruptedException { @@ -55,20 +59,17 @@ public class GroupedActionListenerTests extends ESTestCase { Thread[] threads = new Thread[numThreads]; CyclicBarrier barrier = new CyclicBarrier(numThreads); for (int i = 0; i < numThreads; i++) { - threads[i] = new Thread() { - @Override - public void run() { - try { - barrier.await(10, TimeUnit.SECONDS); - } catch (Exception e) { - throw new AssertionError(e); - } - int c = 0; - while((c = count.incrementAndGet()) <= groupSize) { - listener.onResponse(c-1); - } + threads[i] = new Thread(() -> { + try { + barrier.await(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new AssertionError(e); } - }; + int c = 0; + while((c = count.incrementAndGet()) <= groupSize) { + listener.onResponse(c-1); + } + }); threads[i].start(); } for (Thread t : threads) { @@ -100,11 +101,9 @@ public class GroupedActionListenerTests extends ESTestCase { excRef.set(e); } }; - Collection defaults = randomBoolean() ? Collections.singletonList(-1) : - Collections.emptyList(); + Collection defaults = randomBoolean() ? Collections.singletonList(-1) : Collections.emptyList(); int size = randomIntBetween(3, 4); - GroupedActionListener listener = new GroupedActionListener<>(result, size, - defaults); + GroupedActionListener listener = new GroupedActionListener<>(result, size, defaults); listener.onResponse(0); IOException ioException = new IOException(); RuntimeException rtException = new RuntimeException(); @@ -121,4 +120,23 @@ public class GroupedActionListenerTests extends ESTestCase { listener.onResponse(1); assertNull(resRef.get()); } + + public void testConcurrentFailures() throws InterruptedException { + AtomicReference finalException = new AtomicReference<>(); + int numGroups = randomIntBetween(10, 100); + GroupedActionListener listener = new GroupedActionListener<>( + ActionListener.wrap(r -> {}, finalException::set), numGroups, Collections.emptyList()); + ExecutorService executorService = Executors.newFixedThreadPool(numGroups); + for (int i = 0; i < numGroups; i++) { + executorService.submit(() -> listener.onFailure(new IOException())); + } + + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + + Exception exception = finalException.get(); + assertNotNull(exception); + assertThat(exception, instanceOf(IOException.class)); + assertEquals(numGroups - 1, exception.getSuppressed().length); + } }