diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java index 6e91b19df4d..04cdab3b1fe 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java @@ -144,6 +144,9 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor if (needToProcessBatch()) { runAllOpsOnBatch(); + if (inputChannel.isFinished()) { + return ReturnOrAwait.runAgain(); + } flushAllRowsAndCols(); } return ReturnOrAwait.runAgain(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java index 02cb02360d9..5d1b350ca92 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java @@ -156,6 +156,90 @@ public class WindowOperatorQueryFrameProcessorTest extends FrameProcessorTestBas } } + @Test + public void testOutputChannelReachingCapacity() throws IOException + { + // This test validates that we don't end up writing multiple (2) frames to the output channel while reading from the input channel, + // in the scenario when the input channel has finished and receiver's completed() gets called. + final ReadableInput factChannel = buildWindowTestInputChannel(); + + RowSignature inputSignature = RowSignature.builder() + .add("cityName", ColumnType.STRING) + .add("added", ColumnType.LONG) + .build(); + + FrameReader frameReader = FrameReader.create(inputSignature); + + RowSignature outputSignature = RowSignature.builder() + .addAll(inputSignature) + .add("w0", ColumnType.LONG) + .build(); + + final WindowOperatorQuery query = new WindowOperatorQuery( + new QueryDataSource( + Druids.newScanQueryBuilder() + .dataSource(new TableDataSource("test")) + .intervals(new LegacySegmentSpec(Intervals.ETERNITY)) + .columns("cityName", "added") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(new HashMap<>()) + .build()), + new LegacySegmentSpec(Intervals.ETERNITY), + new HashMap<>( + // This ends up satisfying the criteria of needToProcessBatch() method, + // so we end up processing the rows we've read, hence writing the 1st frame to the output channel. + ImmutableMap.of(MultiStageQueryContext.MAX_ROWS_MATERIALIZED_IN_WINDOW, 12) + ), + outputSignature, + ImmutableList.of( + new WindowOperatorFactory(new WindowRowNumberProcessor("w0")) + ), + ImmutableList.of() + ); + + final FrameWriterFactory frameWriterFactory = new LimitedFrameWriterFactory( + FrameWriters.makeRowBasedFrameWriterFactory( + new ArenaMemoryAllocatorFactory(1 << 20), + outputSignature, + Collections.emptyList(), + false + ), + INPUT_ROWS.size() / 4 // This forces frameWriter's capacity to be reached, hence requiring another write. + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + final WindowOperatorQueryFrameProcessor processor = new WindowOperatorQueryFrameProcessor( + query, + factChannel.getChannel(), + outputChannel.writable(), + frameWriterFactory, + frameReader, + new ObjectMapper(), + ImmutableList.of( + new WindowOperatorFactory(new WindowRowNumberProcessor("w0")) + ) + ); + + exec.runFully(processor, null); + + final Sequence> rowsFromProcessor = FrameTestUtil.readRowsFromFrameChannel( + outputChannel.readable(), + FrameReader.create(outputSignature) + ); + + List> outputRows = rowsFromProcessor.toList(); + Assert.assertEquals(INPUT_ROWS.size(), outputRows.size()); + + for (int i = 0; i < INPUT_ROWS.size(); i++) { + Map inputRow = INPUT_ROWS.get(i); + List outputRow = outputRows.get(i); + + Assert.assertEquals("cityName should match", inputRow.get("cityName"), outputRow.get(0)); + Assert.assertEquals("added should match", inputRow.get("added"), outputRow.get(1)); + Assert.assertEquals("row_number() should be correct", (long) i + 1, outputRow.get(2)); + } + } + @Test public void testProcessorRun() throws Exception {