From 033ae233e8787787473b8eff17e2272db8597f5a Mon Sep 17 00:00:00 2001
From: Gian Merlino
Date: Wed, 7 Sep 2022 07:02:21 -0700
Subject: [PATCH] MSQ extension: Fix over-capacity write in ScanQueryFrameProcessor. (#13036)

* MSQ extension: Fix over-capacity write in ScanQueryFrameProcessor.

Frame processors are meant to write only one output frame per cycle. The
ScanQueryFrameProcessor would write two when reading from a channel if the
input frame cursor cycled and then the output frame filled up while reading
from the next frame.

This patch fixes the bug, and adds a test. It also makes some adjustments
to the processor code in order to make it easier to test.

* Add license header.
---
 .../msq/querykit/BaseLeafFrameProcessor.java  |  14 +-
 .../BaseLeafFrameProcessorFactory.java        |   3 +-
 .../GroupByPreShuffleFrameProcessor.java      |  25 +-
 ...roupByPreShuffleFrameProcessorFactory.java |  22 +-
 .../scan/ScanQueryFrameProcessor.java         |  37 ++-
 .../scan/ScanQueryFrameProcessorFactory.java  |  22 +-
 .../scan/ScanQueryFrameProcessorTest.java     | 240 ++++++++++++++++++
 7 files changed, 299 insertions(+), 64 deletions(-)
 create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java

diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java
index 01bcacfd2e8..f971a4cc73f 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java
@@ -24,13 +24,13 @@ import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.collections.ResourceHolder;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.FrameProcessors;
 import org.apache.druid.frame.processor.ReturnOrAwait;
 import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.msq.input.ReadableInput;
@@ -55,7 +55,7 @@ public abstract class BaseLeafFrameProcessor implements FrameProcessor
   private final ReadableInput baseInput;
   private final List<ReadableFrameChannel> inputChannels;
   private final ResourceHolder<WritableFrameChannel> outputChannel;
-  private final ResourceHolder<MemoryAllocator> allocator;
+  private final ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder;
   private final BroadcastJoinHelper broadcastJoinHelper;

   private Function<SegmentReference, SegmentReference> segmentMapFn;
@@ -66,14 +66,14 @@ public abstract class BaseLeafFrameProcessor implements FrameProcessor
       final Int2ObjectMap<ReadableFrameChannel> sideChannels,
       final JoinableFactoryWrapper joinableFactory,
       final ResourceHolder<WritableFrameChannel> outputChannel,
-      final ResourceHolder<MemoryAllocator> allocator,
+      final ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
       final long memoryReservedForBroadcastJoin
   )
   {
     this.query = query;
     this.baseInput = baseInput;
     this.outputChannel = outputChannel;
-    this.allocator = allocator;
+    this.frameWriterFactoryHolder = frameWriterFactoryHolder;

     final Pair<List<ReadableFrameChannel>, BroadcastJoinHelper> inputChannelsAndBroadcastJoinHelper =
         makeInputChannelsAndBroadcastJoinHelper(
@@ -120,12 +120,12 @@ public abstract class BaseLeafFrameProcessor implements FrameProcessor
   {
     // Don't close the output channel, because multiple workers write to the same channel.
     // The channel should be closed by the caller.
-    FrameProcessors.closeAll(inputChannels(), Collections.emptyList(), outputChannel, allocator);
+    FrameProcessors.closeAll(inputChannels(), Collections.emptyList(), outputChannel, frameWriterFactoryHolder);
   }

-  protected MemoryAllocator getAllocator()
+  protected FrameWriterFactory getFrameWriterFactory()
   {
-    return allocator.get();
+    return frameWriterFactoryHolder.get();
   }

   protected abstract ReturnOrAwait<Long> runWithSegment(SegmentWithDescriptor segment) throws IOException;
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java
index c844a11a15a..ef1275355c8 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java
@@ -147,8 +147,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFa
               }
             }
         ),
-        makeLazyResourceHolder(allocatorQueueRef, ignored -> {
-        }),
+        makeLazyResourceHolder(allocatorQueueRef, ignored -> {}),
         stageDefinition.getSignature(),
         stageDefinition.getClusterBy(),
         frameContext
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java
index af407f01442..5bcc4267b97 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java
@@ -23,12 +23,9 @@ import com.google.common.collect.Iterables;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import org.apache.druid.collections.ResourceHolder;
 import org.apache.druid.frame.Frame;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.FrameRowTooLargeException;
 import org.apache.druid.frame.processor.ReturnOrAwait;
@@ -36,7 +33,6 @@ import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.segment.FrameSegment;
 import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
-import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.guava.Yielder;
@@ -66,8 +62,6 @@ public class GroupByPreShuffleFrameProcessor extends BaseLeafFrameProcessor
 {
   private final GroupByQuery query;
   private final GroupByStrategySelector strategySelector;
-  private final RowSignature aggregationSignature;
-  private final ClusterBy clusterBy;
   private final ColumnSelectorFactory frameWriterColumnSelectorFactory;

   private final Closer closer = Closer.create();
@@ -82,10 +76,8 @@ public class GroupByPreShuffleFrameProcessor extends BaseLeafFrameProcessor
       final Int2ObjectMap<ReadableFrameChannel> sideChannels,
       final GroupByStrategySelector strategySelector,
       final JoinableFactoryWrapper joinableFactory,
-      final RowSignature aggregationSignature,
-      final ClusterBy clusterBy,
       final ResourceHolder<WritableFrameChannel> outputChannel,
-      final ResourceHolder<MemoryAllocator> allocator,
+      final ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
       final long memoryReservedForBroadcastJoin
   )
   {
@@ -95,13 +87,11 @@ public class GroupByPreShuffleFrameProcessor extends BaseLeafFrameProcessor
         sideChannels,
         joinableFactory,
         outputChannel,
-        allocator,
+        frameWriterFactoryHolder,
         memoryReservedForBroadcastJoin
     );
     this.query = query;
     this.strategySelector = strategySelector;
-    this.aggregationSignature = aggregationSignature;
-    this.clusterBy = clusterBy;
     this.frameWriterColumnSelectorFactory = RowBasedGrouperHelper.createResultRowBasedColumnSelectorFactory(
         query,
         () -> resultYielder.get(),
@@ -209,16 +199,9 @@ public class GroupByPreShuffleFrameProcessor extends BaseLeafFrameProcessor
   private void createFrameWriterIfNeeded()
   {
     if (frameWriter == null) {
-      final MemoryAllocator allocator = getAllocator();
-      final FrameWriterFactory frameWriterFactory =
-          FrameWriters.makeFrameWriterFactory(
-              FrameType.ROW_BASED,
-              allocator,
-              aggregationSignature,
-              clusterBy.getColumns()
-          );
+      final FrameWriterFactory frameWriterFactory = getFrameWriterFactory();
       frameWriter = frameWriterFactory.newFrameWriter(frameWriterColumnSelectorFactory);
-      currentAllocatorCapacity = allocator.capacity();
+      currentAllocatorCapacity = frameWriterFactory.allocatorCapacity();
     }
   }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java
index b73f00efa2e..63ad3cf8909 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java
@@ -25,13 +25,17 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.base.Preconditions;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import org.apache.druid.collections.ResourceHolder;
+import org.apache.druid.frame.FrameType;
 import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.WritableFrameChannel;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
+import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.msq.input.ReadableInput;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.querykit.BaseLeafFrameProcessorFactory;
+import org.apache.druid.msq.querykit.LazyResourceHolder;
 import org.apache.druid.query.groupby.GroupByQuery;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.join.JoinableFactoryWrapper;
@@ -57,8 +61,8 @@ public class GroupByPreShuffleFrameProcessorFactory extends BaseLeafFrameProcess
   protected FrameProcessor<Long> makeProcessor(
       final ReadableInput baseInput,
       final Int2ObjectMap<ReadableFrameChannel> sideChannels,
-      final ResourceHolder<WritableFrameChannel> outputChannelSupplier,
-      final ResourceHolder<MemoryAllocator> allocatorSupplier,
+      final ResourceHolder<WritableFrameChannel> outputChannelHolder,
+      final ResourceHolder<MemoryAllocator> allocatorHolder,
       final RowSignature signature,
       final ClusterBy clusterBy,
       final FrameContext frameContext
@@ -70,10 +74,16 @@ public class GroupByPreShuffleFrameProcessorFactory extends BaseLeafFrameProcess
         sideChannels,
         frameContext.groupByStrategySelector(),
         new JoinableFactoryWrapper(frameContext.joinableFactory()),
-        signature,
-        clusterBy,
-        outputChannelSupplier,
-        allocatorSupplier,
+        outputChannelHolder,
+        new LazyResourceHolder<>(() -> Pair.of(
+            FrameWriters.makeFrameWriterFactory(
+                FrameType.ROW_BASED,
+                allocatorHolder.get(),
+                signature,
+                clusterBy.getColumns()
+            ),
+            allocatorHolder
+        )),
         frameContext.memoryParameters().getBroadcastJoinMemory()
     );
   }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java
index 1b8d21baae1..307d274c73b 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java
@@ -25,12 +25,9 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.collections.ResourceHolder;
 import org.apache.druid.frame.Frame;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.FrameRowTooLargeException;
 import org.apache.druid.frame.processor.ReturnOrAwait;
@@ -39,7 +36,6 @@ import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.segment.FrameSegment;
 import org.apache.druid.frame.util.SettableLongVirtualColumn;
 import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
-import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.granularity.Granularities;
@@ -60,7 +56,6 @@ import org.apache.druid.segment.Cursor;
 import org.apache.druid.segment.StorageAdapter;
 import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.segment.VirtualColumns;
-import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.filter.Filters;
 import org.apache.druid.segment.join.JoinableFactoryWrapper;
 import org.apache.druid.timeline.SegmentId;
@@ -79,8 +74,6 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
 {
   private final ScanQuery query;
-  private final RowSignature signature;
-  private final ClusterBy clusterBy;
   private final AtomicLong runningCountForLimit;
   private final SettableLongVirtualColumn partitionBoostVirtualColumn;
   private final VirtualColumns frameWriterVirtualColumns;
@@ -93,13 +86,11 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
   public ScanQueryFrameProcessor(
       final ScanQuery query,
-      final RowSignature signature,
-      final ClusterBy clusterBy,
       final ReadableInput baseInput,
       final Int2ObjectMap<ReadableFrameChannel> sideChannels,
       final JoinableFactoryWrapper joinableFactory,
       final ResourceHolder<WritableFrameChannel> outputChannel,
-      final ResourceHolder<MemoryAllocator> allocator,
+      final ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
       @Nullable final AtomicLong runningCountForLimit,
       final long memoryReservedForBroadcastJoin
   )
@@ -110,12 +101,10 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
         sideChannels,
         joinableFactory,
         outputChannel,
-        allocator,
+        frameWriterFactoryHolder,
         memoryReservedForBroadcastJoin
     );
     this.query = query;
-    this.signature = signature;
-    this.clusterBy = clusterBy;
     this.runningCountForLimit = runningCountForLimit;
     this.partitionBoostVirtualColumn = new SettableLongVirtualColumn(QueryKitUtils.PARTITION_BOOST_COLUMN);
@@ -174,7 +163,8 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
       cursorYielder.close();
       return ReturnOrAwait.returnObject(rowsOutput);
     } else {
-      setNextCursor(cursorYielder.get());
+      final long rowsFlushed = setNextCursor(cursorYielder.get());
+      assert rowsFlushed == 0; // There's only ever one cursor when running with a segment
       closer.register(cursorYielder);
     }
   }
@@ -201,9 +191,9 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
     if (cursor == null || cursor.isDone()) {
       if (inputChannel.canRead()) {
         final Frame frame = inputChannel.read();
-        final FrameSegment frameSegment = new FrameSegment(frame, inputFrameReader, SegmentId.dummy("x"));
+        final FrameSegment frameSegment = new FrameSegment(frame, inputFrameReader, SegmentId.dummy("scan"));

-        setNextCursor(
+        final long rowsFlushed = setNextCursor(
             Iterables.getOnlyElement(
                 makeCursors(
                     query.withQuerySegmentSpec(new MultipleIntervalSegmentSpec(Intervals.ONLY_ETERNITY)),
                 ).toList()
             )
         );
+
+        if (rowsFlushed > 0) {
+          return ReturnOrAwait.runAgain();
+        }
       } else if (inputChannel.isFinished()) {
         flushFrameWriter();
         return ReturnOrAwait.returnObject(rowsOutput);
@@ -256,13 +250,11 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
   private void createFrameWriterIfNeeded()
   {
     if (frameWriter == null) {
-      final MemoryAllocator allocator = getAllocator();
-      final FrameWriterFactory frameWriterFactory =
-          FrameWriters.makeFrameWriterFactory(FrameType.ROW_BASED, allocator, signature, clusterBy.getColumns());
+      final FrameWriterFactory frameWriterFactory = getFrameWriterFactory();
       final ColumnSelectorFactory frameWriterColumnSelectorFactory =
           frameWriterVirtualColumns.wrap(cursor.getColumnSelectorFactory());
       frameWriter = frameWriterFactory.newFrameWriter(frameWriterColumnSelectorFactory);
-      currentAllocatorCapacity = allocator.capacity();
+      currentAllocatorCapacity = frameWriterFactory.allocatorCapacity();
     }
   }
@@ -285,10 +277,11 @@ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor
     }
   }

-  private void setNextCursor(final Cursor cursor) throws IOException
+  private long setNextCursor(final Cursor cursor) throws IOException
   {
-    flushFrameWriter();
+    final long rowsFlushed = flushFrameWriter();
     this.cursor = cursor;
+    return rowsFlushed;
   }

   private static Sequence<Cursor> makeCursors(final ScanQuery query, final StorageAdapter adapter)
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java
index 08b41ddfc3c..2a948fd4562 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java
@@ -25,13 +25,17 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.base.Preconditions;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import org.apache.druid.collections.ResourceHolder;
+import org.apache.druid.frame.FrameType;
 import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.WritableFrameChannel;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
+import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.msq.input.ReadableInput;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.querykit.BaseLeafFrameProcessorFactory;
+import org.apache.druid.msq.querykit.LazyResourceHolder;
 import org.apache.druid.query.scan.ScanQuery;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.join.JoinableFactoryWrapper;
@@ -73,8 +77,8 @@ public class ScanQueryFrameProcessorFactory extends BaseLeafFrameProcessorFactor
   protected FrameProcessor<Long> makeProcessor(
       ReadableInput baseInput,
       Int2ObjectMap<ReadableFrameChannel> sideChannels,
-      ResourceHolder<WritableFrameChannel> outputChannelSupplier,
-      ResourceHolder<MemoryAllocator> allocatorSupplier,
+      ResourceHolder<WritableFrameChannel> outputChannelHolder,
+      ResourceHolder<MemoryAllocator> allocatorHolder,
      RowSignature signature,
      ClusterBy clusterBy,
      FrameContext frameContext
@@ -82,13 +86,19 @@ public class ScanQueryFrameProcessorFactor
  {
     return new ScanQueryFrameProcessor(
         query,
-        signature,
-        clusterBy,
         baseInput,
         sideChannels,
         new JoinableFactoryWrapper(frameContext.joinableFactory()),
-        outputChannelSupplier,
-        allocatorSupplier,
+        outputChannelHolder,
+        new LazyResourceHolder<>(() -> Pair.of(
+            FrameWriters.makeFrameWriterFactory(
+                FrameType.ROW_BASED,
+                allocatorHolder.get(),
+                signature,
+                clusterBy.getColumns()
+            ),
+            allocatorHolder
+        )),
         runningCountForLimit,
         frameContext.memoryParameters().getBroadcastJoinMemory()
     );
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
new file mode 100644
index 00000000000..2ea2958c736
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.querykit.scan;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.druid.collections.ResourceHolder;
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
+import org.apache.druid.frame.allocation.HeapMemoryAllocator;
+import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
+import org.apache.druid.frame.channel.WritableFrameChannel;
+import org.apache.druid.frame.processor.FrameProcessorExecutor;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.testutil.FrameSequenceBuilder;
+import org.apache.druid.frame.testutil.FrameTestUtil;
+import org.apache.druid.frame.write.FrameWriter;
+import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.msq.input.ReadableInput;
+import org.apache.druid.msq.kernel.StageId;
+import org.apache.druid.msq.kernel.StagePartition;
+import org.apache.druid.msq.querykit.LazyResourceHolder;
+import org.apache.druid.query.Druids;
+import org.apache.druid.query.scan.ScanQuery;
+import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.TestIndex;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter;
+import org.apache.druid.segment.join.JoinableFactoryWrapper;
+import org.apache.druid.segment.join.NoopJoinableFactory;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+public class ScanQueryFrameProcessorTest extends InitializedNullHandlingTest
+{
+  private FrameProcessorExecutor exec;
+
+  @Before
+  public void setUp()
+  {
+    exec = new FrameProcessorExecutor(MoreExecutors.listeningDecorator(Execs.singleThreaded("test-exec")));
+  }
+
+  @After
+  public void tearDown() throws Exception
+  {
+    exec.getExecutorService().shutdownNow();
+    exec.getExecutorService().awaitTermination(10, TimeUnit.MINUTES);
+  }
+
+  @Test
+  public void test_runWithInputChannel() throws Exception
+  {
+    final IncrementalIndexStorageAdapter adapter =
+        new IncrementalIndexStorageAdapter(TestIndex.getIncrementalTestIndex());
+
+    final FrameSequenceBuilder frameSequenceBuilder =
+        FrameSequenceBuilder.fromAdapter(adapter)
+                            .maxRowsPerFrame(5)
+                            .frameType(FrameType.ROW_BASED)
+                            .allocator(ArenaMemoryAllocator.createOnHeap(100_000));
+
+    final RowSignature signature = frameSequenceBuilder.signature();
+    final List<Frame> frames = frameSequenceBuilder.frames().toList();
+    final BlockingQueueFrameChannel inputChannel = new BlockingQueueFrameChannel(frames.size());
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    try (final WritableFrameChannel writableInputChannel = inputChannel.writable()) {
+      for (final Frame frame : frames) {
+        writableInputChannel.write(frame);
+      }
+    }
+
+    final ScanQuery query =
+        Druids.newScanQueryBuilder()
+              .dataSource("test")
+              .intervals(new MultipleIntervalSegmentSpec(Intervals.ONLY_ETERNITY))
+              .columns(adapter.getRowSignature().getColumnNames())
+              .legacy(false)
+              .build();
+
+    final StagePartition stagePartition = new StagePartition(new StageId("query", 0), 0);
+
+    // Limit output frames to 1 row to ensure we test edge cases
+    final FrameWriterFactory frameWriterFactory = limitedFrameWriterFactory(
+        FrameWriters.makeFrameWriterFactory(
+            FrameType.ROW_BASED,
+            HeapMemoryAllocator.unlimited(),
+            signature,
+            Collections.emptyList()
+        ),
+        1
+    );
+
+    final ScanQueryFrameProcessor processor = new ScanQueryFrameProcessor(
+        query,
+        ReadableInput.channel(inputChannel.readable(), FrameReader.create(signature), stagePartition),
+        Int2ObjectMaps.emptyMap(),
+        new JoinableFactoryWrapper(NoopJoinableFactory.INSTANCE),
+        new ResourceHolder<WritableFrameChannel>()
+        {
+          @Override
+          public WritableFrameChannel get()
+          {
+            return outputChannel.writable();
+          }
+
+          @Override
+          public void close()
+          {
+            try {
+              outputChannel.writable().close();
+            }
+            catch (IOException e) {
+              throw new RuntimeException(e);
+            }
+          }
+        },
+        new LazyResourceHolder<>(() -> Pair.of(frameWriterFactory, () -> {})),
+        null,
+        0L
+    );
+
+    ListenableFuture<Long> retVal = exec.runFully(processor, null);
+
+    final Sequence<List<Object>> rowsFromProcessor = FrameTestUtil.readRowsFromFrameChannel(
+        outputChannel.readable(),
+        FrameReader.create(signature)
+    );
+
+    FrameTestUtil.assertRowsEqual(
+        FrameTestUtil.readRowsFromAdapter(adapter, signature, false),
+        rowsFromProcessor
+    );
+
+    Assert.assertEquals(adapter.getNumRows(), (long) retVal.get());
+  }
+
+  /**
+   * Wraps a {@link FrameWriterFactory}, creating a new factory that returns {@link FrameWriter} instances
+   * which write a limited number of rows.
+   */
+  private static FrameWriterFactory limitedFrameWriterFactory(final FrameWriterFactory baseFactory, final int rowLimit)
+  {
+    return new FrameWriterFactory()
+    {
+      @Override
+      public FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory)
+      {
+        return new LimitedFrameWriter(baseFactory.newFrameWriter(columnSelectorFactory), rowLimit);
+      }
+
+      @Override
+      public long allocatorCapacity()
+      {
+        return baseFactory.allocatorCapacity();
+      }
+    };
+  }
+
+  private static class LimitedFrameWriter implements FrameWriter
+  {
+    private final FrameWriter baseWriter;
+    private final int rowLimit;
+
+    public LimitedFrameWriter(FrameWriter baseWriter, int rowLimit)
+    {
+      this.baseWriter = baseWriter;
+      this.rowLimit = rowLimit;
+    }
+
+    @Override
+    public boolean addSelection()
+    {
+      if (baseWriter.getNumRows() >= rowLimit) {
+        return false;
+      } else {
+        return baseWriter.addSelection();
+      }
+    }
+
+    @Override
+    public int getNumRows()
+    {
+      return baseWriter.getNumRows();
+    }
+
+    @Override
+    public long getTotalSize()
+    {
+      return baseWriter.getTotalSize();
+    }
+
+    @Override
+    public long writeTo(WritableMemory memory, long position)
+    {
+      return baseWriter.writeTo(memory, position);
+    }
+
+    @Override
+    public void close()
+    {
+      baseWriter.close();
+    }
+  }
+}
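To make the contract behind this fix concrete for readers outside the MSQ code, here is a small, self-contained sketch of the rule the commit message states: a processor may emit at most one output frame per run cycle, so when advancing to the next input frame causes a pending output frame to be flushed, the processor must yield and ask to be run again rather than continue filling (and possibly flushing) a second frame in the same cycle. This is not Druid code; every type and name below (OneFramePerCycleSketch, Signal, flush, and so on) is hypothetical and only models the control flow enforced by the patch.

// Hypothetical, simplified model of the "one output frame per cycle" contract.
import java.util.ArrayDeque;
import java.util.Queue;

class OneFramePerCycleSketch
{
  enum Signal { RUN_AGAIN, DONE } // result of one cycle

  private final Queue<int[]> input = new ArrayDeque<>();    // pretend input frames: arrays of rows
  private final Queue<Integer> output = new ArrayDeque<>(); // pretend output channel: row counts per frame
  private int[] currentFrame = null;
  private int position = 0;
  private int bufferedRows = 0;
  private static final int OUTPUT_FRAME_CAPACITY = 2;

  OneFramePerCycleSketch(Queue<int[]> frames)
  {
    input.addAll(frames);
  }

  Signal runIncrementally()
  {
    if (currentFrame == null || position == currentFrame.length) {
      // Advancing to the next input frame may flush a partially filled output frame.
      final int rowsFlushed = flush();
      if (input.isEmpty()) {
        return Signal.DONE;
      }
      currentFrame = input.poll();
      position = 0;
      if (rowsFlushed > 0) {
        // The fix: one output frame has already been written this cycle, so yield
        // instead of filling (and possibly flushing) a second one before returning.
        return Signal.RUN_AGAIN;
      }
    }
    while (position < currentFrame.length && bufferedRows < OUTPUT_FRAME_CAPACITY) {
      bufferedRows++; // copy one row into the output frame
      position++;
    }
    if (bufferedRows == OUTPUT_FRAME_CAPACITY) {
      flush(); // the single permitted output frame for this cycle
    }
    return Signal.RUN_AGAIN;
  }

  private int flush()
  {
    final int flushed = bufferedRows;
    if (flushed > 0) {
      output.add(flushed); // stand-in for writing one frame to the output channel
      bufferedRows = 0;
    }
    return flushed;
  }

  public static void main(String[] args)
  {
    final Queue<int[]> frames = new ArrayDeque<>();
    frames.add(new int[]{1, 2, 3});
    frames.add(new int[]{4, 5});
    final OneFramePerCycleSketch processor = new OneFramePerCycleSketch(frames);
    while (processor.runIncrementally() != Signal.DONE) {
      // keep cycling, as an executor would
    }
    System.out.println("output frames (row counts): " + processor.output);
  }
}

Without the early return, a cycle that both crossed an input-frame boundary and then filled the output buffer could write two frames, which is exactly the over-capacity write the ScanQueryFrameProcessorTest above provokes by limiting output frames to one row.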