diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java
index a0572a91b4d..d5f6393d738 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessor.java
@@ -83,7 +83,6 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
   private final FrameWriterFactory frameWriterFactory;
   private final FrameReader frameReader;
   private final int maxRowsMaterialized;
-  private long currentAllocatorCapacity; // Used for generating FrameRowTooLargeException if needed
   private Cursor frameCursor = null;
   private Supplier<ResultRow> rowSupplierFromFrameCursor;
   private ResultRow outputRow = null;
@@ -99,6 +98,8 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
   private final ArrayList<ResultRow> rowsToProcess;
   private int lastPartitionIndex = -1;
 
+  final AtomicInteger rowId = new AtomicInteger(0);
+
   public WindowOperatorQueryFrameProcessor(
       WindowOperatorQuery query,
       ReadableFrameChannel inputChannel,
@@ -155,7 +156,7 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
   }
 
   @Override
-  public ReturnOrAwait runIncrementally(IntSet readableInputs)
+  public ReturnOrAwait runIncrementally(IntSet readableInputs) throws IOException
   {
     /*
     There are 2 scenarios:
@@ -216,32 +217,54 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
     Most of the window operations like SUM(), RANK(), RANGE() etc. can be made with 2 passes of the data.
     We might think to reimplement them in the MSQ way so that we do not have to materialize so much data.
      */
+    // If there are rows pending flush, flush them and run again before processing any more rows.
+    if (frameHasRowsPendingFlush()) {
+      flushAllRowsAndCols();
+      return ReturnOrAwait.runAgain();
+    }
+
     if (partitionColumnNames.isEmpty()) {
       // Scenario 1: Query has atleast one window function with an OVER() clause without a PARTITION BY.
       if (inputChannel.canRead()) {
         final Frame frame = inputChannel.read();
         convertRowFrameToRowsAndColumns(frame);
         return ReturnOrAwait.runAgain();
-      } else if (inputChannel.isFinished()) {
-        runAllOpsOnMultipleRac(frameRowsAndCols);
-        return ReturnOrAwait.returnObject(Unit.instance());
-      } else {
-        return ReturnOrAwait.awaitAll(inputChannels().size());
       }
+
+      if (inputChannel.isFinished()) {
+        // If no rows are flushed yet, process all rows.
+        if (rowId.get() == 0) {
+          runAllOpsOnMultipleRac(frameRowsAndCols);
+        }
+
+        // If there are still rows pending after operations, run again.
+        if (frameHasRowsPendingFlush()) {
+          return ReturnOrAwait.runAgain();
+        }
+        return ReturnOrAwait.returnObject(Unit.instance());
+      }
+      return ReturnOrAwait.awaitAll(inputChannels().size());
     }
     // Scenario 2: All window functions in the query have OVER() clause with a PARTITION BY
     if (frameCursor == null || frameCursor.isDone()) {
       if (readableInputs.isEmpty()) {
         return ReturnOrAwait.awaitAll(1);
-      } else if (inputChannel.canRead()) {
+      }
+
+      if (inputChannel.canRead()) {
         final Frame frame = inputChannel.read();
         frameCursor = FrameProcessors.makeCursor(frame, frameReader);
         makeRowSupplierFromFrameCursor();
       } else if (inputChannel.isFinished()) {
-        // Handle any remaining data.
-        lastPartitionIndex = rowsToProcess.size() - 1;
-        processRowsUpToLastPartition();
+        // If we have some rows pending processing, process them.
+        // We run it again as it's possible that frame writer's capacity got reached and some output rows are
+        // pending flush to the output channel.
+        if (!rowsToProcess.isEmpty()) {
+          lastPartitionIndex = rowsToProcess.size() - 1;
+          processRowsUpToLastPartition();
+          return ReturnOrAwait.runAgain();
+        }
         return ReturnOrAwait.returnObject(Unit.instance());
       } else {
         return ReturnOrAwait.runAgain();
       }
@@ -313,41 +336,30 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
       public void completed()
       {
         try {
-          // resultRowsAndCols has reference to frameRowsAndCols
-          // due to the chain of calls across the ops
-          // so we can clear after writing to output
-          flushAllRowsAndCols(resultRowAndCols);
-          frameRowsAndCols.clear();
-
+          flushAllRowsAndCols();
         }
         catch (IOException e) {
          throw new RuntimeException(e);
        }
-        finally {
-          frameRowsAndCols.clear();
-          resultRowAndCols.clear();
-        }
      }
    });
  }
 
   /**
-   * @param resultRowAndCols Flush the list of {@link RowsAndColumns} to a frame
+   * Flushes {@link #resultRowAndCols} to the frame starting from {@link #rowId}, upto the frame writer's capacity.
    * @throws IOException
    */
-  private void flushAllRowsAndCols(ArrayList<RowsAndColumns> resultRowAndCols) throws IOException
+  private void flushAllRowsAndCols() throws IOException
   {
     RowsAndColumns rac = new ConcatRowsAndColumns(resultRowAndCols);
-    AtomicInteger rowId = new AtomicInteger(0);
-    createFrameWriterIfNeeded(rac, rowId);
-    writeRacToFrame(rac, rowId);
+    createFrameWriterIfNeeded(rac);
+    writeRacToFrame(rac);
   }
 
   /**
    * @param rac The frame writer to write this {@link RowsAndColumns} object
-   * @param rowId RowId to get the column selector factory from the {@link RowsAndColumns} object
    */
-  private void createFrameWriterIfNeeded(RowsAndColumns rac, AtomicInteger rowId)
+  private void createFrameWriterIfNeeded(RowsAndColumns rac)
   {
     if (frameWriter == null) {
       final ColumnSelectorFactoryMaker csfm = ColumnSelectorFactoryMaker.fromRAC(rac);
@@ -355,32 +367,38 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
       final ColumnSelectorFactory frameWriterColumnSelectorFactoryWithVirtualColumns =
           frameWriterVirtualColumns.wrap(frameWriterColumnSelectorFactory);
       frameWriter = frameWriterFactory.newFrameWriter(frameWriterColumnSelectorFactoryWithVirtualColumns);
-      currentAllocatorCapacity = frameWriterFactory.allocatorCapacity();
     }
   }
 
   /**
    * @param rac {@link RowsAndColumns} to be written to frame
-   * @param rowId Counter to keep track of how many rows are added
    * @throws IOException
    */
-  public void writeRacToFrame(RowsAndColumns rac, AtomicInteger rowId) throws IOException
+  public void writeRacToFrame(RowsAndColumns rac) throws IOException
   {
     final int numRows = rac.numRows();
-    rowId.set(0);
     while (rowId.get() < numRows) {
-      final boolean didAddToFrame = frameWriter.addSelection();
-      if (didAddToFrame) {
+      if (frameWriter.addSelection()) {
+        incrementBoostColumn();
         rowId.incrementAndGet();
-        partitionBoostVirtualColumn.setValue(partitionBoostVirtualColumn.getValue() + 1);
-      } else if (frameWriter.getNumRows() == 0) {
-        throw new FrameRowTooLargeException(currentAllocatorCapacity);
-      } else {
+      } else if (frameWriter.getNumRows() > 0) {
         flushFrameWriter();
-        return;
+        createFrameWriterIfNeeded(rac);
+
+        if (frameWriter.addSelection()) {
+          incrementBoostColumn();
+          rowId.incrementAndGet();
+          return;
+        } else {
+          throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+        }
+      } else {
+        throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
       }
     }
+    flushFrameWriter();
+    clearRACBuffers();
   }
 
   @Override
@@ -521,4 +539,28 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor
       ));
     }
   }
+
+  /**
+   * Increments the value of the partition boosting column. It should be called once the row value has been written
+   * to the frame
+   */
+  private void incrementBoostColumn()
+  {
+    partitionBoostVirtualColumn.setValue(partitionBoostVirtualColumn.getValue() + 1);
+  }
+
+  /**
+   * @return true if frame has rows pending flush to the output channel, false otherwise.
+   */
+  private boolean frameHasRowsPendingFlush()
+  {
+    return frameWriter != null && frameWriter.getNumRows() > 0;
+  }
+
+  private void clearRACBuffers()
+  {
+    frameRowsAndCols.clear();
+    resultRowAndCols.clear();
+    rowId.set(0);
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java
index 9d64fffe23e..e5d191bbb5c 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/WindowOperatorQueryFrameProcessorTest.java
@@ -66,6 +66,99 @@ import java.util.Map;
 
 public class WindowOperatorQueryFrameProcessorTest extends FrameProcessorTestBase
 {
+  private static final List<Map<String, Object>> INPUT_ROWS = ImmutableList.of(
+      ImmutableMap.of("added", 1L, "cityName", "city1"),
+      ImmutableMap.of("added", 1L, "cityName", "city2"),
+      ImmutableMap.of("added", 2L, "cityName", "city3"),
+      ImmutableMap.of("added", 2L, "cityName", "city4"),
+      ImmutableMap.of("added", 2L, "cityName", "city5"),
+      ImmutableMap.of("added", 3L, "cityName", "city6"),
+      ImmutableMap.of("added", 3L, "cityName", "city7")
+  );
+
+  @Test
+  public void testFrameWriterReachingCapacity() throws IOException
+  {
+    // This test validates that all output rows are flushed to the output channel even if frame writer's
+    // capacity is reached, by subsequent iterations of runIncrementally.
+    final ReadableInput factChannel = buildWindowTestInputChannel();
+
+    RowSignature inputSignature = RowSignature.builder()
+                                              .add("cityName", ColumnType.STRING)
+                                              .add("added", ColumnType.LONG)
+                                              .build();
+
+    FrameReader frameReader = FrameReader.create(inputSignature);
+
+    RowSignature outputSignature = RowSignature.builder()
+                                               .addAll(inputSignature)
+                                               .add("w0", ColumnType.LONG)
+                                               .build();
+
+    final WindowOperatorQuery query = new WindowOperatorQuery(
+        new QueryDataSource(
+            Druids.newScanQueryBuilder()
+                  .dataSource(new TableDataSource("test"))
+                  .intervals(new LegacySegmentSpec(Intervals.ETERNITY))
+                  .columns("cityName", "added")
+                  .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                  .context(new HashMap<>())
+                  .build()),
+        new LegacySegmentSpec(Intervals.ETERNITY),
+        new HashMap<>(),
+        outputSignature,
+        ImmutableList.of(
+            new WindowOperatorFactory(new WindowRowNumberProcessor("w0"))
+        ),
+        ImmutableList.of()
+    );
+
+    final FrameWriterFactory frameWriterFactory = new LimitedFrameWriterFactory(
+        FrameWriters.makeRowBasedFrameWriterFactory(
+            new SingleMemoryAllocatorFactory(HeapMemoryAllocator.unlimited()),
+            outputSignature,
+            Collections.emptyList(),
+            false
+        ),
+        INPUT_ROWS.size() / 4 // This forces frameWriter's capacity to be reached.
+    );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+    final WindowOperatorQueryFrameProcessor processor = new WindowOperatorQueryFrameProcessor(
+        query,
+        factChannel.getChannel(),
+        outputChannel.writable(),
+        frameWriterFactory,
+        frameReader,
+        new ObjectMapper(),
+        ImmutableList.of(
+            new WindowOperatorFactory(new WindowRowNumberProcessor("w0"))
+        ),
+        inputSignature,
+        100,
+        ImmutableList.of("added")
+    );
+
+    exec.runFully(processor, null);
+
+    final Sequence<List<Object>> rowsFromProcessor = FrameTestUtil.readRowsFromFrameChannel(
+        outputChannel.readable(),
+        FrameReader.create(outputSignature)
+    );
+
+    List<List<Object>> outputRows = rowsFromProcessor.toList();
+    Assert.assertEquals(INPUT_ROWS.size(), outputRows.size());
+
+    for (int i = 0; i < INPUT_ROWS.size(); i++) {
+      Map<String, Object> inputRow = INPUT_ROWS.get(i);
+      List<Object> outputRow = outputRows.get(i);
+
+      Assert.assertEquals("cityName should match", inputRow.get("cityName"), outputRow.get(0));
+      Assert.assertEquals("added should match", inputRow.get("added"), outputRow.get(1));
+      Assert.assertEquals("row_number() should be correct", (long) i + 1, outputRow.get(2));
+    }
+  }
+
   @Test
   public void testBatchingOfPartitionByKeys_singleBatch() throws Exception
   {
@@ -195,18 +288,7 @@ public class WindowOperatorQueryFrameProcessorTest extends FrameProcessorTestBas
                                               .add("cityName", ColumnType.STRING)
                                               .add("added", ColumnType.LONG)
                                               .build();
-
-    List<Map<String, Object>> rows = ImmutableList.of(
-        ImmutableMap.of("added", 1L, "cityName", "city1"),
-        ImmutableMap.of("added", 1L, "cityName", "city2"),
-        ImmutableMap.of("added", 2L, "cityName", "city3"),
-        ImmutableMap.of("added", 2L, "cityName", "city4"),
-        ImmutableMap.of("added", 2L, "cityName", "city5"),
-        ImmutableMap.of("added", 3L, "cityName", "city6"),
-        ImmutableMap.of("added", 3L, "cityName", "city7")
-    );
-
-    return makeChannelFromRows(rows, inputSignature, Collections.emptyList());
+    return makeChannelFromRows(INPUT_ROWS, inputSignature, Collections.emptyList());
   }
 
   private ReadableInput makeChannelFromRows(
diff --git a/extensions-core/multi-stage-query/src/test/quidem/org.apache.druid.msq.quidem.MSQQuidemTest/msq2.iq b/extensions-core/multi-stage-query/src/test/quidem/org.apache.druid.msq.quidem.MSQQuidemTest/msq2.iq
new file mode 100644
index 00000000000..73000355615
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/quidem/org.apache.druid.msq.quidem.MSQQuidemTest/msq2.iq
@@ -0,0 +1,54 @@
+!set plannerStrategy DECOUPLED
+!use druidtest://?componentSupplier=DrillWindowQueryMSQComponentSupplier
+!set outputformat mysql
+
+# This test validates that all output rows are flushed to the output channel even if frame writer's capacity is reached.
+
+select count(*) as actualNumRows
+from (
+  select countryName, cityName, channel, added, delta, row_number() over() as rowNumber
+  from wikipedia
+  group by countryName, cityName, channel, added, delta
+);
++---------------+
+| actualNumRows |
++---------------+
+|         11631 |
++---------------+
+(1 row)
+
+!ok
+
+# Validate that all rows are outputted by window WindowOperatorQueryFrameProcessor layer for empty over() clause scenario.
+
+select count(*) as numRows, max(rowNumber) as maxRowNumber
+from (
+  select countryName, cityName, channel, added, delta, row_number() over() as rowNumber
+  from wikipedia
+  group by countryName, cityName, channel, added, delta
+);
++---------+--------------+
+| numRows | maxRowNumber |
++---------+--------------+
+|   11631 |        11631 |
++---------+--------------+
+(1 row)
+
+!ok
+
+# Validate that all rows are outputted by window WindowOperatorQueryFrameProcessor layer for non-empty over() clause scenario.
+
+select rowNumber, count(rowNumber) as numRows
+from (
+  select countryName, cityName, channel, added, delta, row_number() over(partition by countryName, cityName, channel, added, delta) as rowNumber
+  from wikipedia
+  group by countryName, cityName, channel, added, delta
+) group by rowNumber;
++-----------+---------+
+| rowNumber | numRows |
++-----------+---------+
+|         1 |   11631 |
++-----------+---------+
+(1 row)
+
+!ok