Faster k-way merging using tournament trees, 8-byte key strides. (#15661)

* Faster k-way merging using tournament trees, 8-byte key strides.

Two speedups for FrameChannelMerger (which does k-way merging in MSQ):

1) Replace the priority queue with a tournament tree, which does fewer
   comparisons.

2) Compare keys using 8-byte strides, rather than 1 byte at a time.

* Adjust comments.

* Fix style.

* Adjust benchmark and test.

* Add eight-list test (power of two).
This commit is contained in:
Gian Merlino 2024-01-11 08:36:22 -08:00 committed by GitHub
parent 2118258b54
commit 2231cb30a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 845 additions and 48 deletions

View File

@ -0,0 +1,353 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark.frame;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory;
import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.processor.FrameChannelMerger;
import org.apache.druid.frame.processor.FrameProcessorExecutor;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.testutil.FrameSequenceBuilder;
import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.segment.RowBasedSegment;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.timeline.SegmentId;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
* Benchmark for {@link FrameChannelMerger}.
*/
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
public class FrameChannelMergerBenchmark
{
static {
NullHandling.initializeForTests();
}
private static final String KEY = "key";
private static final String VALUE = "value";
@Param({"5000000"})
private int numRows;
@Param({"2", "16"})
private int numChannels;
@Param({"20"})
private int keyLength;
@Param({"100"})
private int rowLength;
/**
* Linked to {@link KeyGenerator}.
*/
@Param({"random", "sequential"})
private String keyGeneratorString;
/**
* Linked to {@link ChannelDistribution}.
*/
@Param({"round_robin", "clustered"})
private String channelDistributionString;
/**
* Generator of keys.
*/
enum KeyGenerator
{
/**
* Random characters from a-z.
*/
RANDOM {
@Override
public String generateKey(int rowNumber, int keyLength)
{
final StringBuilder builder = new StringBuilder(keyLength);
for (int i = 0; i < keyLength; i++) {
builder.append((char) ('a' + ThreadLocalRandom.current().nextInt(26)));
}
return builder.toString();
}
},
/**
* Sequential with zero-padding.
*/
SEQUENTIAL {
@Override
public String generateKey(int rowNumber, int keyLength)
{
return StringUtils.format("%0" + keyLength + "d", rowNumber);
}
};
public abstract String generateKey(int rowNumber, int keyLength);
}
/**
* Distribution of rows across channels.
*/
enum ChannelDistribution
{
/**
* Sequential keys are distributed round-robin to channels.
*/
ROUND_ROBIN {
@Override
public int getChannelNumber(int rowNumber, int numRows, int numChannels)
{
return rowNumber % numChannels;
}
},
/**
* Sequential keys are clustered into the same channels.
*/
CLUSTERED {
@Override
public int getChannelNumber(int rowNumber, int numRows, int numChannels)
{
final int rowsPerChannel = numRows / numChannels;
return rowNumber / rowsPerChannel;
}
};
public abstract int getChannelNumber(int rowNumber, int numRows, int numChannels);
}
private final RowSignature signature =
RowSignature.builder()
.add(KEY, ColumnType.STRING)
.add(VALUE, ColumnType.STRING)
.build();
private final FrameReader frameReader = FrameReader.create(signature);
private final List<KeyColumn> sortKey = ImmutableList.of(new KeyColumn(KEY, KeyOrder.ASCENDING));
private List<List<Frame>> channelFrames;
private FrameProcessorExecutor exec;
private List<BlockingQueueFrameChannel> channels;
/**
* Create {@link #numChannels} channels in {@link #channels}, with {@link #numRows} total rows split across the
* channels according to {@link ChannelDistribution}. Each channel is individually sorted, as required
* by {@link FrameChannelMerger}.
*
* Rows are fixed-length at {@link #rowLength} with fixed-length keys at {@link #keyLength}. Keys are generated
* by {@link KeyGenerator}.
*/
@Setup(Level.Trial)
public void setupTrial()
{
exec = new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat(getClass().getSimpleName()))
)
);
final KeyGenerator keyGenerator = KeyGenerator.valueOf(StringUtils.toUpperCase(keyGeneratorString));
final ChannelDistribution channelDistribution =
ChannelDistribution.valueOf(StringUtils.toUpperCase(channelDistributionString));
// Create channelRows which holds rows for each channel.
final List<List<NonnullPair<String, String>>> channelRows = new ArrayList<>();
channelFrames = new ArrayList<>();
for (int channelNumber = 0; channelNumber < numChannels; channelNumber++) {
channelRows.add(new ArrayList<>());
channelFrames.add(new ArrayList<>());
}
// Create "valueString", a string full of spaces to pad out the row.
final StringBuilder valueStringBuilder = new StringBuilder();
for (int i = 0; i < rowLength - keyLength; i++) {
valueStringBuilder.append(' ');
}
final String valueString = valueStringBuilder.toString();
// Populate "channelRows".
for (int rowNumber = 0; rowNumber < numRows; rowNumber++) {
final String keyString = keyGenerator.generateKey(rowNumber, keyLength);
final NonnullPair<String, String> row = new NonnullPair<>(keyString, valueString);
channelRows.get(channelDistribution.getChannelNumber(rowNumber, numRows, numChannels)).add(row);
}
// Sort each "channelRows".
for (List<NonnullPair<String, String>> rows : channelRows) {
rows.sort(Comparator.comparing(row -> row.lhs));
}
// Populate each "channelFrames".
for (int channelNumber = 0; channelNumber < numChannels; channelNumber++) {
final List<NonnullPair<String, String>> rows = channelRows.get(channelNumber);
final RowBasedSegment<NonnullPair<String, String>> segment =
new RowBasedSegment<>(
SegmentId.dummy("__dummy"),
Sequences.simple(rows),
columnName -> {
if (KEY.equals(columnName)) {
return row -> row.lhs;
} else if (VALUE.equals(columnName)) {
return row -> row.rhs;
} else if (ColumnHolder.TIME_COLUMN_NAME.equals(columnName)) {
return row -> 0L;
} else {
throw new ISE("No such column[%s]", columnName);
}
},
signature
);
final Sequence<Frame> frameSequence =
FrameSequenceBuilder.fromAdapter(segment.asStorageAdapter())
.allocator(ArenaMemoryAllocator.createOnHeap(10_000_000))
.frameType(FrameType.ROW_BASED)
.frames();
final List<Frame> channelFrameList = channelFrames.get(channelNumber);
frameSequence.forEach(channelFrameList::add);
rows.clear();
}
}
/**
* Create {@link #numChannels} channels in {@link #channels}, with {@link #numRows} total rows split across the
* channels according to {@link ChannelDistribution}. Each channel is individually sorted, as required
* by {@link FrameChannelMerger}.
*
* Rows are fixed-length at {@link #rowLength} with fixed-length keys at {@link #keyLength}. Keys are generated
* by {@link KeyGenerator}.
*/
@Setup(Level.Invocation)
public void setupInvocation() throws IOException
{
exec = new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat(getClass().getSimpleName()))
)
);
// Create channels.
channels = new ArrayList<>(numChannels);
for (int channelNumber = 0; channelNumber < numChannels; channelNumber++) {
channels.add(new BlockingQueueFrameChannel(100));
}
// Populate each channel.
for (int channelNumber = 0; channelNumber < numChannels; channelNumber++) {
final List<Frame> frames = channelFrames.get(channelNumber);
final WritableFrameChannel writableChannel = channels.get(channelNumber).writable();
for (Frame frame : frames) {
writableChannel.write(frame);
}
}
// Close all channels.
for (BlockingQueueFrameChannel channel : channels) {
channel.writable().close();
}
}
@TearDown(Level.Trial)
public void tearDown() throws Exception
{
exec.getExecutorService().shutdownNow();
if (!exec.getExecutorService().awaitTermination(1, TimeUnit.MINUTES)) {
throw new ISE("Could not terminate executor after 1 minute");
}
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void mergeChannels(Blackhole blackhole)
{
final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
final FrameChannelMerger merger = new FrameChannelMerger(
channels.stream().map(BlockingQueueFrameChannel::readable).collect(Collectors.toList()),
frameReader,
outputChannel.writable(),
FrameWriters.makeFrameWriterFactory(
FrameType.ROW_BASED,
new ArenaMemoryAllocatorFactory(1_000_000),
signature,
sortKey
),
sortKey,
null,
-1
);
final ListenableFuture<Long> retVal = exec.runFully(merger, null);
while (!outputChannel.readable().isFinished()) {
FutureUtils.getUnchecked(outputChannel.readable().readabilityFuture(), false);
if (outputChannel.readable().canRead()) {
final Frame frame = outputChannel.readable().read();
blackhole.consume(frame);
}
}
if (FutureUtils.getUnchecked(retVal, true) != numRows) {
throw new ISE("Incorrect numRows[%s], expected[%s]", FutureUtils.getUncheckedImmediately(retVal), numRows);
}
}
}

View File

@ -141,8 +141,8 @@ public class MSQLoadedSegmentTests extends MSQTestBase
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L, ""},
new Object[]{1L, "qwe"},
new Object[]{1L, "10.1"},
new Object[]{1L, "tyu"},
new Object[]{1L, "10.1"},
new Object[]{1L, "2"},
new Object[]{1L, "1"},
new Object[]{1L, "def"},

View File

@ -47,7 +47,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
private final Memory dataRegion;
private final int keyFieldCount;
private final List<FieldReader> keyFieldReaders;
private final long firstFieldPosition;
private final int firstFieldPosition;
private final int[] ascDescRunLengths;
private FrameComparisonWidgetImpl(
@ -56,7 +56,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
final Memory rowOffsetRegion,
final Memory dataRegion,
final List<FieldReader> keyFieldReaders,
final long firstFieldPosition,
final int firstFieldPosition,
final int[] ascDescRunLengths
)
{
@ -218,8 +218,8 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
final long rowPosition = getRowPositionInDataRegion(row);
final long otherRowPosition = otherWidgetImpl.getRowPositionInDataRegion(otherRow);
long comparableBytesStartPositionInRow = firstFieldPosition;
long otherComparableBytesStartPositionInRow = otherWidgetImpl.firstFieldPosition;
int comparableBytesStartPositionInRow = firstFieldPosition;
int otherComparableBytesStartPositionInRow = otherWidgetImpl.firstFieldPosition;
boolean ascending = true;
int field = 0;
@ -227,12 +227,12 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
for (int numFields : ascDescRunLengths) {
if (numFields > 0) {
final int nextField = field + numFields;
final long comparableBytesEndPositionInRow = getFieldEndPositionInRow(rowPosition, nextField - 1);
final long otherComparableBytesEndPositionInRow =
final int comparableBytesEndPositionInRow = getFieldEndPositionInRow(rowPosition, nextField - 1);
final int otherComparableBytesEndPositionInRow =
otherWidgetImpl.getFieldEndPositionInRow(otherRowPosition, nextField - 1);
final long comparableBytesLength = comparableBytesEndPositionInRow - comparableBytesStartPositionInRow;
final long otherComparableBytesLength =
final int comparableBytesLength = comparableBytesEndPositionInRow - comparableBytesStartPositionInRow;
final int otherComparableBytesLength =
otherComparableBytesEndPositionInRow - otherComparableBytesStartPositionInRow;
int cmp = FrameReaderUtils.compareMemoryUnsigned(
@ -270,7 +270,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
}
}
long getFieldEndPositionInRow(final long rowPosition, final int fieldNumber)
int getFieldEndPositionInRow(final long rowPosition, final int fieldNumber)
{
assert fieldNumber >= 0 && fieldNumber < signature.size();
return dataRegion.getInt(rowPosition + (long) fieldNumber * Integer.BYTES);

View File

@ -19,9 +19,7 @@
package org.apache.druid.frame.processor;
import it.unimi.dsi.fastutil.ints.IntHeapPriorityQueue;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntPriorityQueue;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.Frame;
import org.apache.druid.frame.channel.FrameWithPartition;
@ -35,7 +33,6 @@ import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.Cursor;
@ -65,12 +62,13 @@ public class FrameChannelMerger implements FrameProcessor<Long>
private final FrameReader frameReader;
private final List<KeyColumn> sortKey;
private final ClusterByPartitions partitions;
private final IntPriorityQueue priorityQueue;
private final TournamentTree tournamentTree;
private final FrameWriterFactory frameWriterFactory;
private final FramePlus[] currentFrames;
private final long rowLimit;
private long rowsOutput = 0;
private int currentPartition = 0;
private int remainingChannels;
// ColumnSelectorFactory that always reads from the current row in the merged sequence.
final MultiColumnSelectorFactory mergedColumnSelectorFactory;
@ -111,13 +109,27 @@ public class FrameChannelMerger implements FrameProcessor<Long>
this.partitions = partitionsToUse;
this.rowLimit = rowLimit;
this.currentFrames = new FramePlus[inputChannels.size()];
this.priorityQueue = new IntHeapPriorityQueue(
this.remainingChannels = 0;
this.tournamentTree = new TournamentTree(
inputChannels.size(),
(k1, k2) -> currentFrames[k1].comparisonWidget.compare(
currentFrames[k1].rowNumber,
currentFrames[k2].comparisonWidget,
currentFrames[k2].rowNumber
)
(k1, k2) -> {
final FramePlus frame1 = currentFrames[k1];
final FramePlus frame2 = currentFrames[k2];
if (frame1 == frame2) {
return 0;
} else if (frame1 == null) {
return 1;
} else if (frame2 == null) {
return -1;
} else {
return currentFrames[k1].comparisonWidget.compare(
currentFrames[k1].rowNumber,
currentFrames[k2].comparisonWidget,
currentFrames[k2].rowNumber
);
}
}
);
final List<Supplier<ColumnSelectorFactory>> frameColumnSelectorFactorySuppliers =
@ -149,13 +161,13 @@ public class FrameChannelMerger implements FrameProcessor<Long>
@Override
public ReturnOrAwait<Long> runIncrementally(final IntSet readableInputs) throws IOException
{
final IntSet awaitSet = populateCurrentFramesAndPriorityQueue();
final IntSet awaitSet = populateCurrentFramesAndTournamentTree();
if (!awaitSet.isEmpty()) {
return ReturnOrAwait.awaitAll(awaitSet);
}
if (priorityQueue.isEmpty()) {
if (finished()) {
// Done!
return ReturnOrAwait.returnObject(rowsOutput);
}
@ -167,7 +179,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
private FrameWithPartition nextFrame()
{
if (priorityQueue.isEmpty()) {
if (finished()) {
throw new NoSuchElementException();
}
@ -175,8 +187,8 @@ public class FrameChannelMerger implements FrameProcessor<Long>
int mergedFramePartition = currentPartition;
RowKey currentPartitionEnd = partitions.get(currentPartition).getEnd();
while (!priorityQueue.isEmpty()) {
final int currentChannel = priorityQueue.firstInt();
while (!finished()) {
final int currentChannel = tournamentTree.getMin();
mergedColumnSelectorFactory.setCurrentFactory(currentChannel);
if (currentPartitionEnd != null) {
@ -206,31 +218,24 @@ public class FrameChannelMerger implements FrameProcessor<Long>
throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
}
// Frame is full. Don't touch the priority queue; instead, return the current frame.
// Frame is full. Return the current frame.
break;
}
if (rowLimit != UNLIMITED && rowsOutput >= rowLimit) {
// Limit reached; we're done.
priorityQueue.clear();
Arrays.fill(currentFrames, null);
remainingChannels = 0;
} else {
// Continue populating the priority queue.
if (currentChannel != priorityQueue.dequeueInt()) {
// There's a bug in this function. Nothing sensible we can really include in this error message.
throw new ISE("Unexpected channel");
}
// Continue reading the currentChannel.
final FramePlus channelFramePlus = currentFrames[currentChannel];
channelFramePlus.advance();
if (!channelFramePlus.cursor.isDone()) {
// Add this channel back to the priority queue, so it pops back out at the right time.
priorityQueue.enqueue(currentChannel);
} else {
if (channelFramePlus.cursor.isDone()) {
// Done reading current frame from "channel".
// Clear it and see if there is another one available for immediate loading.
currentFrames[currentChannel] = null;
remainingChannels--;
final ReadableFrameChannel channel = inputChannels.get(currentChannel);
@ -238,7 +243,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
// Read next frame from this channel.
final Frame frame = channel.read();
currentFrames[currentChannel] = new FramePlus(frame, frameReader, sortKey);
priorityQueue.enqueue(currentChannel);
remainingChannels++;
} else if (channel.isFinished()) {
// Done reading this channel. Fall through and continue with other channels.
} else {
@ -254,6 +259,11 @@ public class FrameChannelMerger implements FrameProcessor<Long>
}
}
private boolean finished()
{
return remainingChannels == 0;
}
@Override
public void cleanup() throws IOException
{
@ -264,7 +274,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
* Populates {@link #currentFrames}, wherever necessary, from any readable input channels. Returns the set of
* channels that are required for population but are not readable.
*/
private IntSet populateCurrentFramesAndPriorityQueue()
private IntSet populateCurrentFramesAndTournamentTree()
{
final IntSet await = new IntOpenHashSet();
@ -275,7 +285,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
if (channel.canRead()) {
final Frame frame = channel.read();
currentFrames[i] = new FramePlus(frame, frameReader, sortKey);
priorityQueue.enqueue(i);
remainingChannels++;
} else if (!channel.isFinished()) {
await.add(i);
}

View File

@ -0,0 +1,211 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.processor;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.ints.IntComparator;
import org.apache.druid.java.util.common.IAE;
import java.util.Arrays;
/**
* Tree-of-losers tournament tree used for K-way merging. The tree contains a fixed set of elements, from 0 (inclusive)
* to {@link #numElements} (exclusive).
*
* The tree represents a tournament played amongst the elements. At all times each node of the tree contains the loser
* of the match at that node. The winners of the matches are not explicitly stored, except for the overall winner of
* the tournament, which is stored in {@code tree[0]}.
*
* When used as part of k-way merge, expected usage is call {@link #getMin()} to retrieve a run number, then read
* an element from the run. On the next call to {@link #getMin()}, the tree internally calls {@link #update()} to
* handle the case where the min needs to change.
*
* Refer to https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree for additional details.
*/
public class TournamentTree
{
/**
* Complete binary tree, with the overall winner (least element) in slot 0, the root of the loser tree in slot 1, and
* otherwise the node in slot i has children in slots 2*i and (2*i)+1. The final layer of the tree, containing the
* actual elements [0..numElements), is not stored in this array (it is implicit).
*/
private final int[] tree;
/**
* Number of elements in the tree.
*/
private final int numElements;
/**
* Number of elements, rounded up to the nearest power of two.
*/
private final int numElementsRounded;
/**
* Comparator for the elements of the tree.
*/
private final IntComparator comparator;
/**
* Whether this tree has been initialized.
*/
private boolean initialized;
/**
* Creates a tree with a certain number of elements.
*
* @param numElements number of elements in the tree
* @param comparator comparator for the elements. Smaller elements "win".
*/
public TournamentTree(final int numElements, final IntComparator comparator)
{
if (numElements < 1) {
throw new IAE("Must have at least one element");
}
this.numElements = numElements;
this.numElementsRounded = HashCommon.nextPowerOfTwo(numElements);
this.comparator = comparator;
this.tree = new int[numElementsRounded];
}
/**
* Get the current minimum element (the overall winner, i.e., the run to pull the next element from in the
* K-way merge).
*/
public int getMin()
{
if (!initialized) {
// Defer initialization until the first getMin() call, since the tree object might be created before the
// comparator is fully valid. (The comparator is typically not valid until at least one row is available
// from each run.)
initialize();
initialized = true;
}
update();
return tree[0];
}
@Override
public String toString()
{
return "TournamentTree{" +
"numElements=" + numElementsRounded +
", tree=" + Arrays.toString(tree) +
'}';
}
/**
* Returns the backing array of the tree. Used in tests.
*/
int[] backingArray()
{
return tree;
}
/**
* Initializes the tree by running a full tournament. At the conclusion of this method, all nodes of {@link #tree}
* are filled in with the loser for the "game" played at that node, except for {@code tree[0]}, which contains the
* overall winner (least element).
*/
private void initialize()
{
if (numElements == 1) {
return;
}
// Allocate a winner tree, which stores the winner in each node (rather than loser). We'll use this temporarily in
// this method, but it won't be stored long-term.
final int[] winnerTree = new int[numElementsRounded];
// Populate the lowest layer of the loser and winner trees. For example: with elements 0, 1, 2, 3, we'll
// compare 0 vs 1 and 2 vs 3.
for (int i = 0; i < numElementsRounded; i += 2) {
final int cmp = compare(i, i + 1);
final int loser, winner;
if (cmp <= 0) {
winner = i;
loser = i + 1;
} else {
winner = i + 1;
loser = i;
}
final int nodeIndex = (tree.length + i) >> 1;
tree[nodeIndex] = loser;
winnerTree[nodeIndex] = winner;
}
// Populate all other layers of the loser and winner trees.
for (int layerSize = numElementsRounded >> 1; layerSize > 1; layerSize >>= 1) {
for (int i = 0; i < layerSize; i += 2) {
// Size of a layer is also the starting offset of the layer, so node i of this layer is at layerSize + i.
final int left = winnerTree[layerSize + i];
final int right = winnerTree[layerSize + i + 1];
final int cmp = compare(left, right);
final int loser, winner;
if (cmp <= 0) {
winner = left;
loser = right;
} else {
winner = right;
loser = left;
}
final int nodeIndex = (layerSize + i) >> 1;
tree[nodeIndex] = loser;
winnerTree[nodeIndex] = winner;
}
}
// Populate tree[0], overall winner; discard winnerTree.
tree[0] = winnerTree[1];
}
/**
* Re-play the tournament from leaf to root, assuming the winner (stored in {@code tree[0]} may have changed its
* ordering relative to other elements.
*/
private void update()
{
int current = tree[0];
for (int nodeIndex = ((current & ~1) + tree.length) >> 1; nodeIndex >= 1; nodeIndex >>= 1) {
int nodeLoser = tree[nodeIndex];
final int cmp = compare(current, nodeLoser);
if (cmp > 0) {
tree[nodeIndex] = current;
current = nodeLoser;
}
}
tree[0] = current;
}
/**
* Compare two elements, which may be outside {@link #numElements}.
*/
private int compare(int a, int b)
{
if (b >= numElements || a >= numElements) {
return Integer.compare(a, b);
} else {
return comparator.compare(a, b);
}
}
}

View File

@ -117,24 +117,51 @@ public class FrameReaderUtils
public static int compareMemoryUnsigned(
final Memory memory1,
final long position1,
final long length1,
final int length1,
final Memory memory2,
final long position2,
final long length2
final int length2
)
{
final long commonLength = Math.min(length1, length2);
final int commonLength = Math.min(length1, length2);
for (long i = 0; i < commonLength; i++) {
final byte byte1 = memory1.getByte(position1 + i);
final byte byte2 = memory2.getByte(position2 + i);
final int cmp = (byte1 & 0xFF) - (byte2 & 0xFF); // Unsigned comparison
for (int i = 0; i < commonLength; i += Long.BYTES) {
final int remaining = commonLength - i;
final long r1 = readComparableLong(memory1, position1 + i, remaining);
final long r2 = readComparableLong(memory2, position2 + i, remaining);
final int cmp = Long.compare(r1, r2);
if (cmp != 0) {
return cmp;
}
}
return Long.compare(length1, length2);
return Integer.compare(length1, length2);
}
public static long readComparableLong(final Memory memory, final long position, final int length)
{
long retVal = 0;
switch (length) {
case 7:
retVal |= (memory.getByte(position + 6) & 0xFFL) << 8;
case 6:
retVal |= (memory.getByte(position + 5) & 0xFFL) << 16;
case 5:
retVal |= (memory.getByte(position + 4) & 0xFFL) << 24;
case 4:
retVal |= (memory.getByte(position + 3) & 0xFFL) << 32;
case 3:
retVal |= (memory.getByte(position + 2) & 0xFFL) << 40;
case 2:
retVal |= (memory.getByte(position + 1) & 0xFFL) << 48;
case 1:
retVal |= (memory.getByte(position) & 0xFFL) << 56;
break;
default:
retVal = Long.reverseBytes(memory.getLong(position));
}
return retVal + Long.MIN_VALUE;
}
/**

View File

@ -0,0 +1,196 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.processor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import it.unimi.dsi.fastutil.ints.IntComparator;
import it.unimi.dsi.fastutil.ints.IntComparators;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
public class TournamentTreeTest
{
@Test
public void test_construction_oneElement()
{
final IntComparator intComparator = IntComparators.NATURAL_COMPARATOR;
final TournamentTree tree = new TournamentTree(1, intComparator);
Assert.assertEquals(0, tree.getMin());
Assert.assertArrayEquals(
"construction",
new int[]{0},
tree.backingArray()
);
}
@Test
public void test_construction_tenElements_natural()
{
final IntComparator intComparator = IntComparators.NATURAL_COMPARATOR;
final TournamentTree tree = new TournamentTree(10, intComparator);
Assert.assertEquals(0, tree.getMin());
Assert.assertArrayEquals(
"construction",
new int[]{0, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15},
tree.backingArray()
);
}
@Test
public void test_construction_tenElements_reverse()
{
final IntComparator intComparator = IntComparators.OPPOSITE_COMPARATOR;
final TournamentTree tree = new TournamentTree(10, intComparator);
Assert.assertEquals(9, tree.getMin());
Assert.assertArrayEquals(
"construction",
new int[]{9, 7, 3, 12, 1, 5, 10, 14, 0, 2, 4, 6, 8, 11, 13, 15},
tree.backingArray()
);
}
@Test
public void test_construction_sixteenElements_reverse()
{
final IntComparator intComparator = IntComparators.OPPOSITE_COMPARATOR;
final TournamentTree tree = new TournamentTree(16, intComparator);
Assert.assertEquals(15, tree.getMin());
Assert.assertArrayEquals(
"construction",
new int[]{15, 7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14},
tree.backingArray()
);
}
@Test
public void test_merge_eightLists()
{
final List<List<Integer>> lists = ImmutableList.of(
ImmutableList.of(0, 1, 1, 5),
ImmutableList.of(0, 4),
ImmutableList.of(1, 5, 5, 6, 9),
ImmutableList.of(1, 6, 7, 8),
ImmutableList.of(2, 2, 3, 5, 7),
ImmutableList.of(0, 2, 4, 8, 9),
ImmutableList.of(1, 2, 4, 6, 7, 7),
ImmutableList.of(1, 3, 6, 7, 7)
);
final List<Deque<Integer>> queues = new ArrayList<>();
for (final List<Integer> list : lists) {
final Deque<Integer> queue = new ArrayDeque<>();
queues.add(queue);
for (int i : list) {
queue.addLast(i);
}
}
final IntComparator intComparator = (a, b) -> {
final Integer itemA = queues.get(a).peek();
final Integer itemB = queues.get(b).peek();
return Ordering.natural().nullsLast().compare(itemA, itemB);
};
final TournamentTree tree = new TournamentTree(lists.size(), intComparator);
final List<Integer> intsRead = new ArrayList<>();
while (queues.get(tree.getMin()).peek() != null) {
intsRead.add(queues.get(tree.getMin()).poll());
}
final List<Integer> expected = new ArrayList<>();
expected.addAll(Arrays.asList(0, 0, 0));
expected.addAll(Arrays.asList(1, 1, 1, 1, 1, 1));
expected.addAll(Arrays.asList(2, 2, 2, 2));
expected.addAll(Arrays.asList(3, 3));
expected.addAll(Arrays.asList(4, 4, 4));
expected.addAll(Arrays.asList(5, 5, 5, 5));
expected.addAll(Arrays.asList(6, 6, 6, 6));
expected.addAll(Arrays.asList(7, 7, 7, 7, 7, 7));
expected.addAll(Arrays.asList(8, 8));
expected.addAll(Arrays.asList(9, 9));
Assert.assertEquals(expected, intsRead);
}
@Test
public void test_merge_tenLists()
{
final List<List<Integer>> lists = ImmutableList.of(
ImmutableList.of(0, 1, 1, 5),
ImmutableList.of(0, 4),
ImmutableList.of(1, 5, 5, 6, 9),
ImmutableList.of(1, 6, 7, 8),
ImmutableList.of(2, 2, 3, 5, 7),
ImmutableList.of(0, 2, 4, 8, 9),
ImmutableList.of(1, 2, 4, 6, 7, 7),
ImmutableList.of(1, 3, 6, 7, 7),
ImmutableList.of(1, 3, 3, 4, 5, 6),
ImmutableList.of(4, 4, 6, 7)
);
final List<Deque<Integer>> queues = new ArrayList<>();
for (final List<Integer> list : lists) {
final Deque<Integer> queue = new ArrayDeque<>();
queues.add(queue);
for (int i : list) {
queue.addLast(i);
}
}
final IntComparator intComparator = (a, b) -> {
final Integer itemA = queues.get(a).peek();
final Integer itemB = queues.get(b).peek();
return Ordering.natural().nullsLast().compare(itemA, itemB);
};
final TournamentTree tree = new TournamentTree(lists.size(), intComparator);
final List<Integer> intsRead = new ArrayList<>();
while (queues.get(tree.getMin()).peek() != null) {
intsRead.add(queues.get(tree.getMin()).poll());
}
final List<Integer> expected = new ArrayList<>();
expected.addAll(Arrays.asList(0, 0, 0));
expected.addAll(Arrays.asList(1, 1, 1, 1, 1, 1, 1));
expected.addAll(Arrays.asList(2, 2, 2, 2));
expected.addAll(Arrays.asList(3, 3, 3, 3));
expected.addAll(Arrays.asList(4, 4, 4, 4, 4, 4));
expected.addAll(Arrays.asList(5, 5, 5, 5, 5));
expected.addAll(Arrays.asList(6, 6, 6, 6, 6, 6));
expected.addAll(Arrays.asList(7, 7, 7, 7, 7, 7, 7));
expected.addAll(Arrays.asList(8, 8));
expected.addAll(Arrays.asList(9, 9));
Assert.assertEquals(expected, intsRead);
}
}