diff --git a/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java b/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java index a18a1c805c3..7ab591a7d3d 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java +++ b/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java @@ -19,6 +19,7 @@ package org.apache.druid.java.util.common.guava; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import com.google.common.collect.Ordering; import org.apache.druid.java.util.common.RE; @@ -81,6 +82,7 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase private final int parallelism; private final long targetTimeNanos; private final Consumer metricsReporter; + private final CancellationGizmo cancellationGizmo; public ParallelMergeCombiningSequence( @@ -152,6 +154,12 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase return finalOutSequence.toYielder(initValue, accumulator); } + @VisibleForTesting + public CancellationGizmo getCancellationGizmo() + { + return cancellationGizmo; + } + /** * Create an output {@link Sequence} that wraps the output {@link BlockingQueue} of a * {@link MergeCombinePartitioningAction} @@ -166,6 +174,7 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase return new BaseSequence<>( new BaseSequence.IteratorMaker>() { + private boolean shouldCancelOnCleanup = true; @Override public Iterator make() { @@ -201,6 +210,7 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase } if (currentBatch.isTerminalResult()) { + shouldCancelOnCleanup = false; return false; } return true; @@ -228,7 +238,9 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase @Override public void cleanup(Iterator iterFromMake) { - // nothing to cleanup + if 
(shouldCancelOnCleanup) { + cancellationGizmo.cancel(new RuntimeException("Already closed")); + } } } ); diff --git a/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java b/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java index 8e2b4e5025c..e459c38db89 100644 --- a/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java +++ b/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java @@ -540,6 +540,32 @@ public class ParallelMergeCombiningSequenceTest assertException(input, 8, 64, 1000, 500); } + @Test + public void testGracefulCloseOfYielderCancelsPool() throws Exception + { + + List> input = new ArrayList<>(); + input.add(nonBlockingSequence(10_000)); + input.add(nonBlockingSequence(9_001)); + input.add(nonBlockingSequence(7_777)); + input.add(nonBlockingSequence(8_500)); + input.add(nonBlockingSequence(5_000)); + input.add(nonBlockingSequence(8_888)); + + assertResultWithEarlyClose(input, 128, 1024, 256, reportMetrics -> { + Assert.assertEquals(2, reportMetrics.getParallelism()); + Assert.assertEquals(6, reportMetrics.getInputSequences()); + // 49166 is total set of results if yielder were fully processed, expect somewhere more than 0 but less than that + // this isn't super indicative of anything really, since closing the yielder would have triggered the baggage + // to run, which runs this metrics reporter function, while the actual processing could still be occurring on the + // pool in the background and the yielder still operates as intended if cancellation isn't in fact happening.
+ // other tests ensure that this is true though (yielder.next throwing an exception for example) + Assert.assertTrue(49166 > reportMetrics.getInputRows()); + Assert.assertTrue(0 < reportMetrics.getInputRows()); + }); + } + + private void assertResult(List> sequences) throws InterruptedException, IOException { assertResult( @@ -611,6 +637,87 @@ public class ParallelMergeCombiningSequenceTest Assert.assertEquals(0, pool.getRunningThreadCount()); combiningYielder.close(); parallelMergeCombineYielder.close(); + // cancellation trigger should not be set if sequence was fully yielded and close is called + // (though shouldn't actually matter even if it was...) + Assert.assertFalse(parallelMergeCombineSequence.getCancellationGizmo().isCancelled()); + } + + private void assertResultWithEarlyClose( + List> sequences, + int batchSize, + int yieldAfter, + int closeYielderAfter, + Consumer reporter + ) + throws InterruptedException, IOException + { + final CombiningSequence combiningSequence = CombiningSequence.create( + new MergeSequence<>(INT_PAIR_ORDERING, Sequences.simple(sequences)), + INT_PAIR_ORDERING, + INT_PAIR_MERGE_FN + ); + + final ParallelMergeCombiningSequence parallelMergeCombineSequence = new ParallelMergeCombiningSequence<>( + pool, + sequences, + INT_PAIR_ORDERING, + INT_PAIR_MERGE_FN, + true, + 5000, + 0, + TEST_POOL_SIZE, + yieldAfter, + batchSize, + ParallelMergeCombiningSequence.DEFAULT_TASK_TARGET_RUN_TIME_MILLIS, + reporter + ); + + Yielder combiningYielder = Yielders.each(combiningSequence); + Yielder parallelMergeCombineYielder = Yielders.each(parallelMergeCombineSequence); + + IntPair prev = null; + + int yields = 0; + while (!combiningYielder.isDone() && !parallelMergeCombineYielder.isDone()) { + if (yields >= closeYielderAfter) { + parallelMergeCombineYielder.close(); + combiningYielder.close(); + break; + } else { + yields++; + Assert.assertEquals(combiningYielder.get(), parallelMergeCombineYielder.get()); + 
Assert.assertNotEquals(parallelMergeCombineYielder.get(), prev); + prev = parallelMergeCombineYielder.get(); + combiningYielder = combiningYielder.next(combiningYielder.get()); + parallelMergeCombineYielder = parallelMergeCombineYielder.next(parallelMergeCombineYielder.get()); + } + } + // advancing the yielder after it has been closed is expected to throw the cancellation exception + final String expectedExceptionMsg = "Already closed"; + try { + Assert.assertEquals(combiningYielder.get(), parallelMergeCombineYielder.get()); + parallelMergeCombineYielder.next(parallelMergeCombineYielder.get()); + // next() should have thrown above, so the following assertion must never be reached + Assert.assertTrue(false); + } + catch (RuntimeException rex) { + Assert.assertEquals(expectedExceptionMsg, rex.getMessage()); + } + + // cancellation gizmo of sequence should be cancelled, and also should contain our expected message + Assert.assertTrue(parallelMergeCombineSequence.getCancellationGizmo().isCancelled()); + Assert.assertEquals( + expectedExceptionMsg, + parallelMergeCombineSequence.getCancellationGizmo().getRuntimeException().getMessage() + ); + + while (pool.getRunningThreadCount() > 0) { + Thread.sleep(100); + } + Assert.assertEquals(0, pool.getRunningThreadCount()); + + Assert.assertFalse(combiningYielder.isDone()); + Assert.assertFalse(parallelMergeCombineYielder.isDone()); } private void assertException(List> sequences) throws Exception