diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java index c9e598c6188..9069794222f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java @@ -69,12 +69,6 @@ public class Limits */ public static final int MAX_KERNEL_MANIPULATION_QUEUE_SIZE = 100_000; - /** - * Maximum number of bytes buffered for each side of a - * {@link org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor}, not counting the most recent frame read. - */ - public static final int MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN = 10_000_000; - /** * Maximum relaunches across all workers. */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index f64e4dbd0ef..4bddb949f07 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -137,29 +137,33 @@ public class WorkerMemoryParameters * we use a value somewhat lower than 0.5. */ static final double BROADCAST_JOIN_MEMORY_FRACTION = 0.3; + + /** + * Fraction of free memory per bundle that can be used by + * {@link org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor} to buffer frames in its trackers. + */ + static final double SORT_MERGE_JOIN_MEMORY_FRACTION = 0.9; + /** * In case {@link NotEnoughMemoryFault} is thrown, a fixed estimation overhead is added when estimating total memory required for the process. 
*/ private static final long BUFFER_BYTES_FOR_ESTIMATION = 1000; + private final long processorBundleMemory; private final int superSorterMaxActiveProcessors; private final int superSorterMaxChannelsPerProcessor; - private final long appenderatorMemory; - private final long broadcastJoinMemory; private final int partitionStatisticsMaxRetainedBytes; WorkerMemoryParameters( + final long processorBundleMemory, final int superSorterMaxActiveProcessors, final int superSorterMaxChannelsPerProcessor, - final long appenderatorMemory, - final long broadcastJoinMemory, final int partitionStatisticsMaxRetainedBytes ) { + this.processorBundleMemory = processorBundleMemory; this.superSorterMaxActiveProcessors = superSorterMaxActiveProcessors; this.superSorterMaxChannelsPerProcessor = superSorterMaxChannelsPerProcessor; - this.appenderatorMemory = appenderatorMemory; - this.broadcastJoinMemory = broadcastJoinMemory; this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes; } @@ -344,10 +348,9 @@ public class WorkerMemoryParameters } return new WorkerMemoryParameters( + bundleMemoryForProcessing, superSorterMaxActiveProcessors, superSorterMaxChannelsPerProcessor, - (long) (bundleMemoryForProcessing * APPENDERATOR_MEMORY_FRACTION), - (long) (bundleMemoryForProcessing * BROADCAST_JOIN_MEMORY_FRACTION), Ints.checkedCast(workerMemory) // 100% of worker memory is devoted to partition statistics ); } @@ -365,13 +368,13 @@ public class WorkerMemoryParameters public long getAppenderatorMaxBytesInMemory() { // Half for indexing, half for merging. - return Math.max(1, appenderatorMemory / 2); + return Math.max(1, getAppenderatorMemory() / 2); } public int getAppenderatorMaxColumnsToMerge() { // Half for indexing, half for merging. 
- return Ints.checkedCast(Math.max(2, appenderatorMemory / 2 / APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN)); + return Ints.checkedCast(Math.max(2, getAppenderatorMemory() / 2 / APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN)); } public int getStandardFrameSize() @@ -386,7 +389,12 @@ public class WorkerMemoryParameters public long getBroadcastJoinMemory() { - return broadcastJoinMemory; + return (long) (processorBundleMemory * BROADCAST_JOIN_MEMORY_FRACTION); + } + + public long getSortMergeJoinMemory() + { + return (long) (processorBundleMemory * SORT_MERGE_JOIN_MEMORY_FRACTION); } public int getPartitionStatisticsMaxRetainedBytes() @@ -394,6 +402,14 @@ public class WorkerMemoryParameters return partitionStatisticsMaxRetainedBytes; } + /** + * Amount of memory to devote to {@link org.apache.druid.segment.realtime.appenderator.Appenderator}. + */ + private long getAppenderatorMemory() + { + return (long) (processorBundleMemory * APPENDERATOR_MEMORY_FRACTION); + } + @Override public boolean equals(Object o) { @@ -404,10 +420,9 @@ public class WorkerMemoryParameters return false; } WorkerMemoryParameters that = (WorkerMemoryParameters) o; - return superSorterMaxActiveProcessors == that.superSorterMaxActiveProcessors + return processorBundleMemory == that.processorBundleMemory + && superSorterMaxActiveProcessors == that.superSorterMaxActiveProcessors && superSorterMaxChannelsPerProcessor == that.superSorterMaxChannelsPerProcessor - && appenderatorMemory == that.appenderatorMemory - && broadcastJoinMemory == that.broadcastJoinMemory && partitionStatisticsMaxRetainedBytes == that.partitionStatisticsMaxRetainedBytes; } @@ -415,10 +430,9 @@ public class WorkerMemoryParameters public int hashCode() { return Objects.hash( + processorBundleMemory, superSorterMaxActiveProcessors, superSorterMaxChannelsPerProcessor, - appenderatorMemory, - broadcastJoinMemory, partitionStatisticsMaxRetainedBytes ); } @@ -427,10 +441,9 @@ public class WorkerMemoryParameters public String toString() 
{ return "WorkerMemoryParameters{" + - "superSorterMaxActiveProcessors=" + superSorterMaxActiveProcessors + + "processorBundleMemory=" + processorBundleMemory + + ", superSorterMaxActiveProcessors=" + superSorterMaxActiveProcessors + ", superSorterMaxChannelsPerProcessor=" + superSorterMaxChannelsPerProcessor + - ", appenderatorMemory=" + appenderatorMemory + - ", broadcastJoinMemory=" + broadcastJoinMemory + ", partitionStatisticsMaxRetainedBytes=" + partitionStatisticsMaxRetainedBytes + '}'; } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java index 21fa363af89..60d355579b6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java @@ -44,10 +44,12 @@ public class TooManyRowsWithSameKeyFault extends BaseMSQFault { super( CODE, - "Too many rows with the same key during sort-merge join (bytes buffered = %,d; limit = %,d). Key: %s", + "Too many rows with the same key[%s] during sort-merge join (bytes buffered[%,d], limit[%,d]). 
" + + "Try increasing heap memory available to workers, " + + "or adjusting your query to process fewer rows with this key.", + key, numBytes, - maxBytes, - key + maxBytes ); this.key = key; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java index 36dc52c5cee..d9e7bc6deec 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java @@ -27,6 +27,7 @@ import org.apache.druid.frame.Frame; import org.apache.druid.frame.channel.ReadableFrameChannel; import org.apache.druid.frame.processor.FrameProcessors; import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.indexing.error.BroadcastTablesTooLargeFault; import org.apache.druid.msq.indexing.error.MSQException; import org.apache.druid.query.DataSource; @@ -58,7 +59,7 @@ public class BroadcastJoinHelper * @param channels list of input channels * @param channelReaders list of input channel readers; corresponds one-to-one with "channels" * @param memoryReservedForBroadcastJoin total bytes of frames we are permitted to use; derived from - * {@link org.apache.druid.msq.exec.WorkerMemoryParameters#broadcastJoinMemory} + * {@link WorkerMemoryParameters#getBroadcastJoinMemory()} */ public BroadcastJoinHelper( final Int2IntMap inputNumberToProcessorChannelMap, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java index 2c454e1d45c..fdc80560f29 100644 --- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java @@ -41,7 +41,6 @@ import org.apache.druid.frame.segment.FrameCursor; import org.apache.druid.frame.write.FrameWriter; import org.apache.druid.frame.write.FrameWriterFactory; import org.apache.druid.java.util.common.ISE; -import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.indexing.error.MSQException; import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; import org.apache.druid.msq.input.ReadableInput; @@ -122,6 +121,7 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor private final String rightPrefix; private final JoinType joinType; private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private final long maxBufferedBytes; private FrameWriter frameWriter = null; // Used by runIncrementally to defer certain logic to the next run. 
@@ -137,7 +137,8 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor FrameWriterFactory frameWriterFactory, String rightPrefix, List> keyColumns, - JoinType joinType + JoinType joinType, + long maxBufferedBytes ) { this.inputChannels = ImmutableList.of(left.getChannel(), right.getChannel()); @@ -146,9 +147,10 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor this.rightPrefix = rightPrefix; this.joinType = joinType; this.trackers = ImmutableList.of( - new Tracker(left, keyColumns.get(LEFT)), - new Tracker(right, keyColumns.get(RIGHT)) + new Tracker(left, keyColumns.get(LEFT), maxBufferedBytes), + new Tracker(right, keyColumns.get(RIGHT), maxBufferedBytes) ); + this.maxBufferedBytes = maxBufferedBytes; } @Override @@ -166,10 +168,10 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor @Override public ReturnOrAwait runIncrementally(IntSet readableInputs) throws IOException { - // Fetch enough frames such that each tracker has one readable row. + // Fetch enough frames such that each tracker has one readable row (or is done). for (int i = 0; i < inputChannels.size(); i++) { final Tracker tracker = trackers.get(i); - if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + if (tracker.needsMoreDataForCurrentCursor() && !pushNextFrame(i)) { return nextAwait(); } } @@ -178,8 +180,8 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor startNewFrameIfNeeded(); while (!allTrackersAreAtEnd() - && !trackers.get(LEFT).needsMoreData() - && !trackers.get(RIGHT).needsMoreData()) { + && !trackers.get(LEFT).needsMoreDataForCurrentCursor() + && !trackers.get(RIGHT).needsMoreDataForCurrentCursor()) { // Algorithm can proceed: not all trackers are at the end of their streams, and no tracker needs more data to // read the current cursor or move it forward. 
if (nextIterationRunnable != null) { @@ -192,21 +194,12 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor // Two rows match if the keys compare equal _and_ neither key has a null component. (x JOIN y ON x.a = y.a does // not match rows where "x.a" is null.) - final boolean match = markCmp == 0 && trackers.get(LEFT).hasCompletelyNonNullMark(); + final boolean marksMatch = markCmp == 0 && trackers.get(LEFT).hasCompletelyNonNullMark(); - // If marked keys are equal on both sides ("match"), at least one side must have a complete set of rows - // for the marked key. - if (match && trackerWithCompleteSetForCurrentKey < 0) { - for (int i = 0; i < inputChannels.size(); i++) { - final Tracker tracker = trackers.get(i); - - // Fetch up to one frame from each tracker, to check if that tracker has a complete set. - // Can't fetch more than one frame, because channels are only guaranteed to have one frame per run. - if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) { - trackerWithCompleteSetForCurrentKey = i; - break; - } - } + // If marked keys are equal on both sides ("marksMatch"), at least one side needs to have a complete set of rows + // for the marked key. Check if this is true, otherwise call nextAwait to read more data. + if (marksMatch && trackerWithCompleteSetForCurrentKey < 0) { + updateTrackerWithCompleteSetForCurrentKey(); if (trackerWithCompleteSetForCurrentKey < 0) { // Algorithm cannot proceed; fetch more frames on the next run. @@ -214,73 +207,13 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor } } - if (match || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && joinType.isRighty())) { - // Emit row, if there's room in the current frameWriter. - joinColumnSelectorFactory.cmp = markCmp; - joinColumnSelectorFactory.match = match; - - if (!frameWriter.addSelection()) { - if (frameWriter.getNumRows() > 0) { - // Out of space in the current frame. 
Run again without moving cursors. - flushCurrentFrame(); - return ReturnOrAwait.runAgain(); - } else { - throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity()); - } - } + // Emit row if there was a match. + if (!emitRowIfNeeded(markCmp, marksMatch)) { + return ReturnOrAwait.runAgain(); } // Advance one or both trackers. - if (match) { - // Matching keys. First advance the tracker with the complete set. - final Tracker tracker = trackers.get(trackerWithCompleteSetForCurrentKey); - final Tracker otherTracker = trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT); - - tracker.advance(); - if (!tracker.isCurrentSameKeyAsMark()) { - // Reached end of complete set. Advance the other tracker. - otherTracker.advance(); - - // On next iteration (when we're sure to have data) either rewind the complete-set tracker, or update marks - // of both, as appropriate. - onNextIteration(() -> { - if (otherTracker.isCurrentSameKeyAsMark()) { - otherTracker.markCurrent(); // Set mark to enable cleanup of old frames. - tracker.rewindToMark(); - } else { - // Reached end of the other side too. Advance marks on both trackers. - tracker.markCurrent(); - otherTracker.markCurrent(); - trackerWithCompleteSetForCurrentKey = -1; - } - }); - } - } else { - final int trackerToAdvance; - - if (markCmp < 0) { - trackerToAdvance = LEFT; - } else if (markCmp > 0) { - trackerToAdvance = RIGHT; - } else { - // Key is null on both sides. Note that there is a preference for running through the left side first - // on a FULL join. It doesn't really matter which side we run through first, but we do need to be consistent - // for the benefit of the logic in "shouldEmitColumnValue". - trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT; - } - - final Tracker tracker = trackers.get(trackerToAdvance); - - tracker.advance(); - - // On next iteration (when we're sure to have data), update mark if the key changed. 
- onNextIteration(() -> { - if (!tracker.isCurrentSameKeyAsMark()) { - tracker.markCurrent(); - trackerWithCompleteSetForCurrentKey = -1; - } - }); - } + advanceTrackersAfterEmittingRow(markCmp, marksMatch); } if (allTrackersAreAtEnd()) { @@ -299,8 +232,152 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor } /** - * Returns a {@link ReturnOrAwait#awaitAll} for the channel numbers that need more data and have not yet hit their - * buffered-bytes limit, {@link Limits#MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN}. + * Set {@link #trackerWithCompleteSetForCurrentKey} to the lowest-numbered {@link Tracker} that has a complete + * set of rows available for its mark. + */ + private void updateTrackerWithCompleteSetForCurrentKey() + { + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + + // Fetch up to one frame from each tracker, to check if that tracker has a complete set. + // Can't fetch more than one frame, because channels are only guaranteed to have one frame per run. + if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) { + trackerWithCompleteSetForCurrentKey = i; + return; + } + } + + trackerWithCompleteSetForCurrentKey = -1; + } + + /** + * Emits a joined row based on the current state of all trackers. + * + * @param markCmp result of {@link #compareMarks()} + * @param marksMatch whether the marks actually matched, taking nulls into account + * + * @return true if cursors should be advanced, false if we should run again without moving cursors + */ + private boolean emitRowIfNeeded(final int markCmp, final boolean marksMatch) throws IOException + { + if (marksMatch || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && joinType.isRighty())) { + // Emit row, if there's room in the current frameWriter. 
+ joinColumnSelectorFactory.cmp = markCmp; + joinColumnSelectorFactory.match = marksMatch; + + if (!frameWriter.addSelection()) { + if (frameWriter.getNumRows() > 0) { + // Out of space in the current frame. Run again without moving cursors. + flushCurrentFrame(); + return false; + } else { + throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity()); + } + } + } + + return true; + } + + /** + * Advance one or both trackers after emitting a row. + * + * @param markCmp result of {@link #compareMarks()} + * @param marksMatch whether the marks actually matched, taking nulls into account + */ + private void advanceTrackersAfterEmittingRow(final int markCmp, final boolean marksMatch) + { + if (marksMatch) { + // Matching keys. First advance the tracker with the complete set. + final Tracker completeSetTracker = trackers.get(trackerWithCompleteSetForCurrentKey); + final Tracker otherTracker = trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT); + + completeSetTracker.advance(); + if (!completeSetTracker.isCurrentSameKeyAsMark()) { + // Reached end of complete set. Advance the other tracker. + otherTracker.advance(); + + // On next iteration (when we're sure to have data) either rewind the complete-set tracker, or update marks + // of both, as appropriate. + onNextIteration(() -> { + if (otherTracker.isCurrentSameKeyAsMark()) { + completeSetTracker.rewindToMark(); + } else { + // Reached end of the other side too. Advance marks on both trackers. + completeSetTracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + + // Always update mark of the other tracker, to enable cleanup of old frames. It doesn't ever need to + // be rewound. + otherTracker.markCurrent(); + }); + } + } else { + // Keys don't match. Advance based on what kind of join this is. 
+ final int trackerToAdvance; + final boolean skipMarkedKey; + + if (markCmp < 0) { + trackerToAdvance = LEFT; + } else if (markCmp > 0) { + trackerToAdvance = RIGHT; + } else { + // Key is null on both sides. Note that there is a preference for running through the left side first + // on a FULL join. It doesn't really matter which side we run through first, but we do need to be consistent + // for the benefit of the logic in "shouldEmitColumnValue". + trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT; + } + + // Skip marked key entirely if we're on the "off" side of the join. (i.e., right side of a LEFT join.) + // Note that for FULL joins, entire keys are never skipped, because they are both lefty and righty. + if (trackerToAdvance == LEFT) { + skipMarkedKey = !joinType.isLefty(); + } else { + skipMarkedKey = !joinType.isRighty(); + } + + final Tracker tracker = trackers.get(trackerToAdvance); + + // Advance past marked key, or as far as we can. + boolean didKeyChange = false; + + do { + // Always advance a single row. If we're in "skipMarkedKey" mode, then we'll loop through later and + // potentially skip multiple rows with the same marked key. + tracker.advance(); + + if (tracker.isAtEndOfPushedData()) { + break; + } + + didKeyChange = !tracker.isCurrentSameKeyAsMark(); + + // Always update mark, even if key hasn't changed, to enable cleanup of old frames. + tracker.markCurrent(); + } while (skipMarkedKey && !didKeyChange); + + if (didKeyChange) { + trackerWithCompleteSetForCurrentKey = -1; + } else if (tracker.isAtEndOfPushedData()) { + // Not clear if we reached a new key or not. + // So, on next iteration (when we're sure to have data), check if we've moved on to a new key. + onNextIteration(() -> { + if (!tracker.isCurrentSameKeyAsMark()) { + trackerWithCompleteSetForCurrentKey = -1; + } + + // Always update mark, even if key hasn't changed, to enable cleanup of old frames. 
+ tracker.markCurrent(); + }); + } + } + } + + /** + * Returns a {@link ReturnOrAwait#awaitAll} for channels where {@link Tracker#needsMoreDataForCurrentCursor()} + * and {@link Tracker#canBufferMoreFrames()}. * * If all channels have hit their limit, throws {@link MSQException} with {@link TooManyRowsWithSameKeyFault}. */ @@ -309,10 +386,11 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor final IntSet awaitSet = new IntOpenHashSet(); int trackerAtLimit = -1; + // Add all trackers that "needsMoreData" to awaitSet. for (int i = 0; i < inputChannels.size(); i++) { final Tracker tracker = trackers.get(i); - if (tracker.needsMoreData()) { - if (tracker.totalBytesBuffered() < Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + if (tracker.needsMoreDataForCurrentCursor()) { + if (tracker.canBufferMoreFrames()) { awaitSet.add(i); } else if (trackerAtLimit < 0) { trackerAtLimit = i; @@ -320,19 +398,31 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor } } - if (awaitSet.isEmpty() && trackerAtLimit > 0) { + if (awaitSet.isEmpty()) { + // No tracker reported that it "needsMoreData" to read the current cursor. However, we may still need to read + // more data to have a complete set for the current mark. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (!tracker.hasCompleteSetForMark()) { + if (tracker.canBufferMoreFrames()) { + awaitSet.add(i); + } else if (trackerAtLimit < 0) { + trackerAtLimit = i; + } + } + } + } + + if (awaitSet.isEmpty() && trackerAtLimit >= 0) { // All trackers that need more data are at their max buffered bytes limit. Generate a nice exception. final Tracker tracker = trackers.get(trackerAtLimit); - if (tracker.totalBytesBuffered() > Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { - // Generate a nice exception. 
- throw new MSQException( - new TooManyRowsWithSameKeyFault( - tracker.readMarkKey(), - tracker.totalBytesBuffered(), - Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN - ) - ); - } + throw new MSQException( + new TooManyRowsWithSameKeyFault( + tracker.readMarkKey(), + tracker.totalBytesBuffered(), + maxBufferedBytes + ) + ); } return ReturnOrAwait.awaitAll(awaitSet); @@ -353,7 +443,13 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor } /** - * Compares the marked rows of the two {@link #trackers}. + * Compares the marked rows of the two {@link #trackers}. This method returns 0 if both sides are null, even + * though this is not considered a match by join semantics. Therefore, it is important to also check + * {@link Tracker#hasCompletelyNonNullMark()}. + * + * @return negative if {@link #LEFT} key is earlier, positive if {@link #RIGHT} key is earlier, zero if the keys + * are the same. Returns zero even if a key component is null, even though this is not considered a match by + * join semantics. * * @throws IllegalStateException if either tracker does not have a marked row and is not completely done */ @@ -394,6 +490,8 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor } else if (channel.isFinished()) { tracker.push(null); return true; + } else if (!tracker.canBufferMoreFrames()) { + return false; } else { final Frame frame = channel.read(); @@ -450,6 +548,7 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor private final List holders = new ArrayList<>(); private final ReadableInput input; private final List keyColumns; + private final long maxBytesBuffered; // markFrame and markRow are the first frame and row with the current key. private int markFrame = -1; @@ -461,10 +560,11 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor // done indicates that no more data is available in the channel. 
private boolean done; - public Tracker(ReadableInput input, List keyColumns) + public Tracker(ReadableInput input, List keyColumns, long maxBytesBuffered) { this.input = input; this.keyColumns = keyColumns; + this.maxBytesBuffered = maxBytesBuffered; } /** @@ -533,6 +633,16 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor return bytes; } + /** + * Whether this tracker can accept more frames without exceeding {@link #maxBufferedBytes}. Always returns true + * if the number of buffered frames is zero or one, because the join algorithm may require two frames being + * buffered. (For example, if we need to verify that the last row in a frame contains a complete set of a key.) + */ + public boolean canBufferMoreFrames() + { + return holders.size() <= 1 || totalBytesBuffered() < maxBytesBuffered; + } + /** * Cursor containing the current row. */ @@ -655,7 +765,7 @@ public class SortMergeJoinFrameProcessor implements FrameProcessor /** * Whether this tracker needs more data in order to read the current cursor location or move it forward. 
*/ - public boolean needsMoreData() + public boolean needsMoreDataForCurrentCursor() { return !done && isAtEndOfPushedData(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java index 9aa50630929..76e05d3ce0c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java @@ -180,7 +180,8 @@ public class SortMergeJoinFrameProcessorFactory extends BaseFrameProcessorFactor stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()), rightPrefix, keyColumns, - joinType + joinType, + frameContext.memoryParameters().getSortMergeJoinMemory() ); } ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java index 78ecacbef26..29614fc0734 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java @@ -32,11 +32,11 @@ public class WorkerMemoryParametersTest @Test public void test_oneWorkerInJvm_alone() { - Assert.assertEquals(params(1, 41, 224_785_000, 100_650_000, 75_000_000), create(1_000_000_000, 1, 1, 1, 0, 0)); - Assert.assertEquals(params(2, 13, 149_410_000, 66_900_000, 75_000_000), create(1_000_000_000, 1, 2, 1, 0, 0)); - Assert.assertEquals(params(4, 3, 89_110_000, 39_900_000, 75_000_000), create(1_000_000_000, 1, 4, 1, 0, 0)); - Assert.assertEquals(params(3, 2, 48_910_000, 21_900_000, 75_000_000), 
create(1_000_000_000, 1, 8, 1, 0, 0)); - Assert.assertEquals(params(2, 2, 33_448_460, 14_976_922, 75_000_000), create(1_000_000_000, 1, 12, 1, 0, 0)); + Assert.assertEquals(params(335_500_000, 1, 41, 75_000_000), create(1_000_000_000, 1, 1, 1, 0, 0)); + Assert.assertEquals(params(223_000_000, 2, 13, 75_000_000), create(1_000_000_000, 1, 2, 1, 0, 0)); + Assert.assertEquals(params(133_000_000, 4, 3, 75_000_000), create(1_000_000_000, 1, 4, 1, 0, 0)); + Assert.assertEquals(params(73_000_000, 3, 2, 75_000_000), create(1_000_000_000, 1, 8, 1, 0, 0)); + Assert.assertEquals(params(49_923_076, 2, 2, 75_000_000), create(1_000_000_000, 1, 12, 1, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, @@ -54,8 +54,8 @@ public class WorkerMemoryParametersTest @Test public void test_oneWorkerInJvm_twoHundredWorkersInCluster() { - Assert.assertEquals(params(1, 83, 317_580_000, 142_200_000, 150_000_000), create(2_000_000_000, 1, 1, 200, 0, 0)); - Assert.assertEquals(params(2, 27, 166_830_000, 74_700_000, 150_000_000), create(2_000_000_000, 1, 2, 200, 0, 0)); + Assert.assertEquals(params(474_000_000, 1, 83, 150_000_000), create(2_000_000_000, 1, 1, 200, 0, 0)); + Assert.assertEquals(params(249_000_000, 2, 27, 150_000_000), create(2_000_000_000, 1, 2, 200, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, @@ -68,11 +68,11 @@ public class WorkerMemoryParametersTest @Test public void test_fourWorkersInJvm_twoHundredWorkersInCluster() { - Assert.assertEquals(params(1, 150, 679_380_000, 304_200_000, 168_750_000), create(9_000_000_000L, 4, 1, 200, 0, 0)); - Assert.assertEquals(params(2, 62, 543_705_000, 243_450_000, 168_750_000), create(9_000_000_000L, 4, 2, 200, 0, 0)); - Assert.assertEquals(params(4, 22, 374_111_250, 167_512_500, 168_750_000), create(9_000_000_000L, 4, 4, 200, 0, 0)); - Assert.assertEquals(params(4, 14, 204_517_500, 91_575_000, 168_750_000), create(9_000_000_000L, 4, 8, 200, 0, 0)); - Assert.assertEquals(params(4, 8, 68_842_500, 
30_825_000, 168_750_000), create(9_000_000_000L, 4, 16, 200, 0, 0)); + Assert.assertEquals(params(1_014_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 200, 0, 0)); + Assert.assertEquals(params(811_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 200, 0, 0)); + Assert.assertEquals(params(558_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 200, 0, 0)); + Assert.assertEquals(params(305_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 200, 0, 0)); + Assert.assertEquals(params(102_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 200, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, @@ -82,7 +82,7 @@ public class WorkerMemoryParametersTest Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault()); // Make sure 124 actually works, and 125 doesn't. (Verify the error message above.) - Assert.assertEquals(params(4, 3, 16_750_000, 7_500_000, 150_000_000), create(8_000_000_000L, 4, 32, 124, 0, 0)); + Assert.assertEquals(params(25_000_000, 4, 3, 150_000_000), create(8_000_000_000L, 4, 32, 124, 0, 0)); final MSQException e2 = Assert.assertThrows( MSQException.class, @@ -96,8 +96,8 @@ public class WorkerMemoryParametersTest public void test_oneWorkerInJvm_smallWorkerCapacity() { // Supersorter max channels per processer are one less than they are usually to account for extra frames that are required while creating composing output channels - Assert.assertEquals(params(1, 3, 27_604_000, 12_360_000, 9_600_000), create(128_000_000, 1, 1, 1, 0, 0)); - Assert.assertEquals(params(1, 1, 17_956_000, 8_040_000, 9_600_000), create(128_000_000, 1, 2, 1, 0, 0)); + Assert.assertEquals(params(41_200_000, 1, 3, 9_600_000), create(128_000_000, 1, 1, 1, 0, 0)); + Assert.assertEquals(params(26_800_000, 1, 1, 9_600_000), create(128_000_000, 1, 2, 1, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, @@ -120,14 +120,10 @@ public class WorkerMemoryParametersTest @Test public void 
test_fourWorkersInJvm_twoHundredWorkersInCluster_hashPartitions() { - Assert.assertEquals( - params(1, 150, 545_380_000, 244_200_000, 168_750_000), create(9_000_000_000L, 4, 1, 200, 200, 0)); - Assert.assertEquals( - params(2, 62, 409_705_000, 183_450_000, 168_750_000), create(9_000_000_000L, 4, 2, 200, 200, 0)); - Assert.assertEquals( - params(4, 22, 240_111_250, 107_512_500, 168_750_000), create(9_000_000_000L, 4, 4, 200, 200, 0)); - Assert.assertEquals( - params(4, 14, 70_517_500, 31_575_000, 168_750_000), create(9_000_000_000L, 4, 8, 200, 200, 0)); + Assert.assertEquals(params(814_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 200, 200, 0)); + Assert.assertEquals(params(611_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 200, 200, 0)); + Assert.assertEquals(params(358_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 200, 200, 0)); + Assert.assertEquals(params(105_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 200, 200, 0)); final MSQException e = Assert.assertThrows( MSQException.class, @@ -137,7 +133,7 @@ public class WorkerMemoryParametersTest Assert.assertEquals(new TooManyWorkersFault(200, 138), e.getFault()); // Make sure 138 actually works, and 139 doesn't. (Verify the error message above.) 
-    Assert.assertEquals(params(4, 8, 17_922_500, 8_025_000, 168_750_000), create(9_000_000_000L, 4, 16, 138, 138, 0));
+    Assert.assertEquals(params(26_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 138, 138, 0));
 
     final MSQException e2 = Assert.assertThrows(
         MSQException.class,
@@ -165,18 +161,16 @@ public class WorkerMemoryParametersTest
   }
 
   private static WorkerMemoryParameters params(
+      final long processorBundleMemory,
       final int superSorterMaxActiveProcessors,
       final int superSorterMaxChannelsPerProcessor,
-      final long appenderatorMemory,
-      final long broadcastJoinMemory,
       final int partitionStatisticsMaxRetainedBytes
   )
   {
     return new WorkerMemoryParameters(
+        processorBundleMemory,
         superSorterMaxActiveProcessors,
         superSorterMaxChannelsPerProcessor,
-        appenderatorMemory,
-        broadcastJoinMemory,
         partitionStatisticsMaxRetainedBytes
     );
   }
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
index cfc74d792f8..060b14cec12 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
@@ -20,8 +20,10 @@
 package org.apache.druid.msq.querykit.common;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
+import com.google.common.primitives.Ints;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import org.apache.druid.common.config.NullHandling;
@@ -46,6 +48,8 @@ import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.concurrent.Execs;
 import
org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.StagePartition; @@ -58,8 +62,12 @@ import org.apache.druid.segment.join.JoinTestHelper; import org.apache.druid.segment.join.JoinType; import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.timeline.SegmentId; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Assert; +import org.junit.Assume; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -79,6 +87,7 @@ import java.util.concurrent.TimeUnit; public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest { private static final StagePartition STAGE_PARTITION = new StagePartition(new StageId("q", 0), 0); + private static final long MAX_BUFFERED_BYTES = 10_000_000; private final int rowsPerInputFrame; private final int rowsPerOutputFrame; @@ -154,7 +163,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) ), - JoinType.LEFT + JoinType.LEFT, + MAX_BUFFERED_BYTES ); assertResult(processor, outputChannel.readable(), joinSignature, Collections.emptyList()); @@ -198,7 +208,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) ), - JoinType.LEFT + JoinType.LEFT, + MAX_BUFFERED_BYTES ); final List> expectedRows = Arrays.asList( @@ -273,7 +284,8 @@ public class 
SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) ), - JoinType.INNER + JoinType.INNER, + MAX_BUFFERED_BYTES ); assertResult(processor, outputChannel.readable(), joinSignature, Collections.emptyList()); @@ -313,7 +325,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) ), - JoinType.LEFT + JoinType.LEFT, + MAX_BUFFERED_BYTES ); final List> expectedRows = Arrays.asList( @@ -383,7 +396,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest makeFrameWriterFactory(joinSignature), "j0.", ImmutableList.of(Collections.emptyList(), Collections.emptyList()), - JoinType.INNER + JoinType.INNER, + MAX_BUFFERED_BYTES ); final List> expectedRows = Arrays.asList( @@ -495,7 +509,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest new KeyColumn("regionIsoCode", KeyOrder.ASCENDING) ) ), - JoinType.LEFT + JoinType.LEFT, + MAX_BUFFERED_BYTES ); final List> expectedRows = Arrays.asList( @@ -573,7 +588,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)), ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)) ), - JoinType.RIGHT + JoinType.RIGHT, + MAX_BUFFERED_BYTES ); final List> expectedRows = Arrays.asList( @@ -654,7 +670,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)), ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)) ), - JoinType.FULL + JoinType.FULL, + MAX_BUFFERED_BYTES ); final List> expectedRows = Arrays.asList( @@ -732,7 +749,8 @@ public class 
SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING))
         ),
-        JoinType.LEFT
+        JoinType.LEFT,
+        MAX_BUFFERED_BYTES
     );
 
     final String countryCodeForNull;
@@ -825,7 +843,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING))
         ),
-        JoinType.RIGHT
+        JoinType.RIGHT,
+        MAX_BUFFERED_BYTES
     );
 
     final String countryCodeForNull;
@@ -918,7 +937,8 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
         ),
-        JoinType.INNER
+        JoinType.INNER,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -950,6 +970,234 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest
     assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
   }
 
+  @Test
+  public void testInnerJoinCountryIsoCode_withMaxBufferedBytesLimit_succeeds() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        1
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testInnerJoinCountryIsoCode_backwards_withMaxBufferedBytesLimit_succeeds() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("j0.page", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("countryName", ColumnType.STRING)
+                    .add("countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        countriesChannel,
+        factChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        1
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testCountrySelfJoin() throws Exception
+  {
+    final ReadableInput factChannel1 = buildFactInput(ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)));
+    final ReadableInput factChannel2 = buildFactInput(ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("channel", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel1,
+        factChannel2,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        MAX_BUFFERED_BYTES
+    );
+
+    final List<List<Object>> expectedRows = new ArrayList<>();
+
+    final ImmutableMap<String, Long> expectedCounts =
+        ImmutableMap.<String, Long>builder()
+                    .put("#ca.wikipedia", 1L)
+                    .put("#de.wikipedia", 1L)
+                    .put("#en.wikipedia", 196L)
+                    .put("#es.wikipedia", 16L)
+                    .put("#fr.wikipedia", 9L)
+                    .put("#ja.wikipedia", 1L)
+                    .put("#ko.wikipedia", 1L)
+                    .put("#ru.wikipedia", 1L)
+                    .put("#vi.wikipedia", 9L)
+                    .build();
+
+    for (final Map.Entry<String, Long> entry : expectedCounts.entrySet()) {
+      for (int i = 0; i < Ints.checkedCast(entry.getValue()); i++) {
+        expectedRows.add(Collections.singletonList(entry.getKey()));
+      }
+    }
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testCountrySelfJoin_withMaxBufferedBytesLimit_fails() throws Exception
+  {
+    // Test is only valid when rowsPerInputFrame is low enough that we get multiple frames.
+    Assume.assumeThat(rowsPerInputFrame, Matchers.lessThanOrEqualTo(7));
+
+    final ReadableInput factChannel1 = buildFactInput(ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)));
+    final ReadableInput factChannel2 = buildFactInput(ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("channel", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel1,
+        factChannel2,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        1
+    );
+
+    final RuntimeException e = Assert.assertThrows(
+        RuntimeException.class,
+        () -> run(processor, outputChannel.readable(), joinSignature)
+    );
+
+    MatcherAssert.assertThat(e.getCause(), CoreMatchers.instanceOf(RuntimeException.class));
+    MatcherAssert.assertThat(e.getCause().getCause(), CoreMatchers.instanceOf(MSQException.class));
+    MatcherAssert.assertThat(
+        ((MSQException) e.getCause().getCause()).getFault(),
+        CoreMatchers.instanceOf(TooManyRowsWithSameKeyFault.class)
+    );
+  }
+
   private void assertResult(
       final SortMergeJoinFrameProcessor processor,
       final ReadableFrameChannel readableOutputChannel,
@@ -957,14 +1205,25 @@ public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest
       final List<List<Object>> expectedRows
   )
   {
-    final ListenableFuture<Long> retVal = exec.runFully(processor, null);
+    final List<List<Object>> rowsFromProcessor = run(processor, readableOutputChannel, joinSignature);
+    FrameTestUtil.assertRowsEqual(Sequences.simple(expectedRows), Sequences.simple(rowsFromProcessor));
+  }
+
+  private List<List<Object>> run(
+      final SortMergeJoinFrameProcessor processor,
+      final ReadableFrameChannel readableOutputChannel,
+      final RowSignature joinSignature
+  )
+  {
+    final ListenableFuture<Long> retValFromProcessor = exec.runFully(processor, null);
     final Sequence<List<Object>> rowsFromProcessor = FrameTestUtil.readRowsFromFrameChannel(
         readableOutputChannel,
         FrameReader.create(joinSignature)
     );
 
-    FrameTestUtil.assertRowsEqual(Sequences.simple(expectedRows), rowsFromProcessor);
-    Assert.assertEquals(0L, (long) FutureUtils.getUnchecked(retVal, true));
+    final List<List<Object>> rows = rowsFromProcessor.toList();
+    Assert.assertEquals(0L, (long) FutureUtils.getUnchecked(retValFromProcessor, true));
+    return rows;
   }
 
   private ReadableInput buildFactInput(final List<KeyColumn> keyColumns) throws IOException