Modify quantile sketches to add byte[] directly (#13351)

* Modify quantile sketchs to add byte[] directly

* Rename class and add test
This commit is contained in:
Adarsh Sanjeev 2022-11-14 00:24:06 +05:30 committed by GitHub
parent 81d005f267
commit a3edda3b63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 411 additions and 148 deletions

View File

@ -44,13 +44,13 @@ import java.util.NoSuchElementException;
*/ */
public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketchKeyCollector> public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketchKeyCollector>
{ {
private final Comparator<RowKey> comparator; private final Comparator<byte[]> comparator;
private ItemsSketch<RowKey> sketch; private ItemsSketch<byte[]> sketch;
private double averageKeyLength; private double averageKeyLength;
QuantilesSketchKeyCollector( QuantilesSketchKeyCollector(
final Comparator<RowKey> comparator, final Comparator<byte[]> comparator,
@Nullable final ItemsSketch<RowKey> sketch, @Nullable final ItemsSketch<byte[]> sketch,
double averageKeyLength double averageKeyLength
) )
{ {
@ -67,7 +67,7 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
estimatedTotalSketchSizeInBytes += key.estimatedObjectSizeBytes() * weight; estimatedTotalSketchSizeInBytes += key.estimatedObjectSizeBytes() * weight;
for (int i = 0; i < weight; i++) { for (int i = 0; i < weight; i++) {
// Add the same key multiple times to make it "heavier". // Add the same key multiple times to make it "heavier".
sketch.update(key); sketch.update(key.array());
} }
averageKeyLength = (estimatedTotalSketchSizeInBytes / sketch.getN()); averageKeyLength = (estimatedTotalSketchSizeInBytes / sketch.getN());
} }
@ -75,7 +75,7 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
@Override @Override
public void addAll(QuantilesSketchKeyCollector other) public void addAll(QuantilesSketchKeyCollector other)
{ {
final ItemsUnion<RowKey> union = ItemsUnion.getInstance( final ItemsUnion<byte[]> union = ItemsUnion.getInstance(
Math.max(sketch.getK(), other.sketch.getK()), Math.max(sketch.getK(), other.sketch.getK()),
comparator comparator
); );
@ -129,10 +129,10 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
@Override @Override
public RowKey minKey() public RowKey minKey()
{ {
final RowKey minValue = sketch.getMinValue(); final byte[] minValue = sketch.getMinValue();
if (minValue != null) { if (minValue != null) {
return minValue; return RowKey.wrap(minValue);
} else { } else {
throw new NoSuchElementException(); throw new NoSuchElementException();
} }
@ -152,20 +152,20 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
final int numPartitions = Ints.checkedCast(LongMath.divide(sketch.getN(), targetWeight, RoundingMode.CEILING)); final int numPartitions = Ints.checkedCast(LongMath.divide(sketch.getN(), targetWeight, RoundingMode.CEILING));
// numPartitions + 1, because the final quantile is the max, and we won't build a partition based on that. // numPartitions + 1, because the final quantile is the max, and we won't build a partition based on that.
final RowKey[] quantiles = sketch.getQuantiles(numPartitions + 1); final byte[][] quantiles = sketch.getQuantiles(numPartitions + 1);
final List<ClusterByPartition> partitions = new ArrayList<>(); final List<ClusterByPartition> partitions = new ArrayList<>();
for (int i = 0; i < numPartitions; i++) { for (int i = 0; i < numPartitions; i++) {
final boolean isFinalPartition = i == numPartitions - 1; final boolean isFinalPartition = i == numPartitions - 1;
if (isFinalPartition) { if (isFinalPartition) {
partitions.add(new ClusterByPartition(quantiles[i], null)); partitions.add(new ClusterByPartition(RowKey.wrap(quantiles[i]), null));
} else { } else {
final ClusterByPartition partition = new ClusterByPartition(quantiles[i], quantiles[i + 1]); final int cmp = comparator.compare(quantiles[i], quantiles[i + 1]);
final int cmp = comparator.compare(partition.getStart(), partition.getEnd());
if (cmp < 0) { if (cmp < 0) {
// Skip partitions where start == end. // Skip partitions where start == end.
// I don't think start can be greater than end, but if that happens, skip them too! // I don't think start can be greater than end, but if that happens, skip them too!
final ClusterByPartition partition = new ClusterByPartition(RowKey.wrap(quantiles[i]), RowKey.wrap(quantiles[i + 1]));
partitions.add(partition); partitions.add(partition);
} }
} }
@ -177,7 +177,7 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
/** /**
* Retrieves the backing sketch. Exists for usage by {@link QuantilesSketchKeyCollectorFactory}. * Retrieves the backing sketch. Exists for usage by {@link QuantilesSketchKeyCollectorFactory}.
*/ */
ItemsSketch<RowKey> getSketch() ItemsSketch<byte[]> getSketch()
{ {
return sketch; return sketch;
} }

View File

@ -28,7 +28,6 @@ import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory; import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.quantiles.ItemsSketch; import org.apache.datasketches.quantiles.ItemsSketch;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import java.io.IOException; import java.io.IOException;
@ -42,16 +41,16 @@ public class QuantilesSketchKeyCollectorFactory
@VisibleForTesting @VisibleForTesting
static final int SKETCH_INITIAL_K = 1 << 15; static final int SKETCH_INITIAL_K = 1 << 15;
private final Comparator<RowKey> comparator; private final Comparator<byte[]> comparator;
private QuantilesSketchKeyCollectorFactory(final Comparator<RowKey> comparator) private QuantilesSketchKeyCollectorFactory(final Comparator<byte[]> comparator)
{ {
this.comparator = comparator; this.comparator = comparator;
} }
static QuantilesSketchKeyCollectorFactory create(final ClusterBy clusterBy) static QuantilesSketchKeyCollectorFactory create(final ClusterBy clusterBy)
{ {
return new QuantilesSketchKeyCollectorFactory(clusterBy.keyComparator()); return new QuantilesSketchKeyCollectorFactory(clusterBy.byteKeyComparator());
} }
@Override @Override
@ -78,7 +77,7 @@ public class QuantilesSketchKeyCollectorFactory
public QuantilesSketchKeyCollectorSnapshot toSnapshot(QuantilesSketchKeyCollector collector) public QuantilesSketchKeyCollectorSnapshot toSnapshot(QuantilesSketchKeyCollector collector)
{ {
final String encodedSketch = final String encodedSketch =
StringUtils.encodeBase64String(collector.getSketch().toByteArray(RowKeySerde.INSTANCE)); StringUtils.encodeBase64String(collector.getSketch().toByteArray(ByteRowKeySerde.INSTANCE));
return new QuantilesSketchKeyCollectorSnapshot(encodedSketch, collector.getAverageKeyLength()); return new QuantilesSketchKeyCollectorSnapshot(encodedSketch, collector.getAverageKeyLength());
} }
@ -87,26 +86,26 @@ public class QuantilesSketchKeyCollectorFactory
{ {
final String encodedSketch = snapshot.getEncodedSketch(); final String encodedSketch = snapshot.getEncodedSketch();
final byte[] bytes = StringUtils.decodeBase64String(encodedSketch); final byte[] bytes = StringUtils.decodeBase64String(encodedSketch);
final ItemsSketch<RowKey> sketch = final ItemsSketch<byte[]> sketch =
ItemsSketch.getInstance(Memory.wrap(bytes), comparator, RowKeySerde.INSTANCE); ItemsSketch.getInstance(Memory.wrap(bytes), comparator, ByteRowKeySerde.INSTANCE);
return new QuantilesSketchKeyCollector(comparator, sketch, snapshot.getAverageKeyLength()); return new QuantilesSketchKeyCollector(comparator, sketch, snapshot.getAverageKeyLength());
} }
private static class RowKeySerde extends ArrayOfItemsSerDe<RowKey> private static class ByteRowKeySerde extends ArrayOfItemsSerDe<byte[]>
{ {
private static final RowKeySerde INSTANCE = new RowKeySerde(); private static final ByteRowKeySerde INSTANCE = new ByteRowKeySerde();
private RowKeySerde() private ByteRowKeySerde()
{ {
} }
@Override @Override
public byte[] serializeToByteArray(final RowKey[] items) public byte[] serializeToByteArray(final byte[][] items)
{ {
int serializedSize = Integer.BYTES * items.length; int serializedSize = Integer.BYTES * items.length;
for (final RowKey key : items) { for (final byte[] key : items) {
serializedSize += key.array().length; serializedSize += key.length;
} }
final byte[] serializedBytes = new byte[serializedSize]; final byte[] serializedBytes = new byte[serializedSize];
@ -114,8 +113,7 @@ public class QuantilesSketchKeyCollectorFactory
long keyWritePosition = (long) Integer.BYTES * items.length; long keyWritePosition = (long) Integer.BYTES * items.length;
for (int i = 0; i < items.length; i++) { for (int i = 0; i < items.length; i++) {
final RowKey key = items[i]; final byte[] keyBytes = items[i];
final byte[] keyBytes = key.array();
writableMemory.putInt((long) Integer.BYTES * i, keyBytes.length); writableMemory.putInt((long) Integer.BYTES * i, keyBytes.length);
writableMemory.putByteArray(keyWritePosition, keyBytes, 0, keyBytes.length); writableMemory.putByteArray(keyWritePosition, keyBytes, 0, keyBytes.length);
@ -128,9 +126,9 @@ public class QuantilesSketchKeyCollectorFactory
} }
@Override @Override
public RowKey[] deserializeFromMemory(final Memory mem, final int numItems) public byte[][] deserializeFromMemory(final Memory mem, final int numItems)
{ {
final RowKey[] keys = new RowKey[numItems]; final byte[][] keys = new byte[numItems][];
long keyPosition = (long) Integer.BYTES * numItems; long keyPosition = (long) Integer.BYTES * numItems;
for (int i = 0; i < numItems; i++) { for (int i = 0; i < numItems; i++) {
@ -138,7 +136,7 @@ public class QuantilesSketchKeyCollectorFactory
final byte[] keyBytes = new byte[keyLength]; final byte[] keyBytes = new byte[keyLength];
mem.getByteArray(keyPosition, keyBytes, 0, keyLength); mem.getByteArray(keyPosition, keyBytes, 0, keyLength);
keys[i] = RowKey.wrap(keyBytes); keys[i] = keyBytes;
keyPosition += keyLength; keyPosition += keyLength;
} }

View File

@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.key;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import org.apache.druid.frame.read.FrameReaderUtils;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
/**
* Comparator for byte arrays from {@link RowKey#key} instances.
*
* Comparison logic in this class is very similar to {@link FrameComparisonWidget}, but is different because it works
* on byte[] instead of Frames.
*/
public class ByteRowKeyComparator implements Comparator<byte[]>
{
private final int firstFieldPosition;
private final int[] ascDescRunLengths;
private ByteRowKeyComparator(
final int firstFieldPosition,
final int[] ascDescRunLengths
)
{
this.firstFieldPosition = firstFieldPosition;
this.ascDescRunLengths = ascDescRunLengths;
}
public static ByteRowKeyComparator create(final List<SortColumn> keyColumns)
{
return new ByteRowKeyComparator(
computeFirstFieldPosition(keyColumns.size()),
computeAscDescRunLengths(keyColumns)
);
}
/**
* Compute the offset into each key where the first field starts.
*
* Public so {@link FrameComparisonWidgetImpl} can use it.
*/
public static int computeFirstFieldPosition(final int fieldCount)
{
return Ints.checkedCast((long) fieldCount * Integer.BYTES);
}
/**
* Given a list of sort columns, compute an array of the number of ascending fields in a run, followed by number of
* descending fields in a run, followed by ascending, etc. For example: ASC, ASC, DESC, ASC would return [2, 1, 1]
* and DESC, DESC, ASC would return [0, 2, 1].
*
* Public so {@link FrameComparisonWidgetImpl} can use it.
*/
public static int[] computeAscDescRunLengths(final List<SortColumn> sortColumns)
{
final IntList ascDescRunLengths = new IntArrayList(4);
boolean descending = false;
int runLength = 0;
for (final SortColumn column : sortColumns) {
if (column.descending() != descending) {
ascDescRunLengths.add(runLength);
runLength = 0;
descending = !descending;
}
runLength++;
}
if (runLength > 0) {
ascDescRunLengths.add(runLength);
}
return ascDescRunLengths.toIntArray();
}
@Override
@SuppressWarnings("SubtractionInCompareTo")
public int compare(final byte[] keyArray1, final byte[] keyArray2)
{
// Similar logic to FrameComparaisonWidgetImpl, but implementation is different enough that we need our own.
// Major difference is Frame v. Frame instead of byte[] v. byte[].
int comparableBytesStartPosition1 = firstFieldPosition;
int comparableBytesStartPosition2 = firstFieldPosition;
boolean ascending = true;
int field = 0;
for (int numFields : ascDescRunLengths) {
if (numFields > 0) {
final int nextField = field + numFields;
final int comparableBytesEndPosition1 = RowKeyReader.fieldEndPosition(keyArray1, nextField - 1);
final int comparableBytesEndPosition2 = RowKeyReader.fieldEndPosition(keyArray2, nextField - 1);
int cmp = FrameReaderUtils.compareByteArraysUnsigned(
keyArray1,
comparableBytesStartPosition1,
comparableBytesEndPosition1 - comparableBytesStartPosition1,
keyArray2,
comparableBytesStartPosition2,
comparableBytesEndPosition2 - comparableBytesStartPosition2
);
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
field = nextField;
comparableBytesStartPosition1 = comparableBytesEndPosition1;
comparableBytesStartPosition2 = comparableBytesEndPosition2;
}
ascending = !ascending;
}
return 0;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ByteRowKeyComparator that = (ByteRowKeyComparator) o;
return firstFieldPosition == that.firstFieldPosition
&& Arrays.equals(ascDescRunLengths, that.ascDescRunLengths);
}
@Override
public int hashCode()
{
int result = Objects.hash(firstFieldPosition);
result = 31 * result + Arrays.hashCode(ascDescRunLengths);
return result;
}
@Override
public String toString()
{
return "ByteRowKeyComparator{" +
"firstFieldPosition=" + firstFieldPosition +
", ascDescRunLengths=" + Arrays.toString(ascDescRunLengths) +
'}';
}
}

View File

@ -125,6 +125,14 @@ public class ClusterBy
return RowKeyComparator.create(columns); return RowKeyComparator.create(columns);
} }
/**
* Comparator that compares byte arrays of keys for this instance using the given signature directly.
*/
public Comparator<byte[]> byteKeyComparator()
{
return ByteRowKeyComparator.create(columns);
}
/** /**
* Comparator that compares bucket keys for this instance. Bucket keys are retrieved by calling * Comparator that compares bucket keys for this instance. Bucket keys are retrieved by calling
* {@link RowKeyReader#trim(RowKey, int)} with {@link #getBucketByCount()}. * {@link RowKeyReader#trim(RowKey, int)} with {@link #getBucketByCount()}.

View File

@ -89,8 +89,8 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
frame.region(RowBasedFrameWriter.ROW_OFFSET_REGION), frame.region(RowBasedFrameWriter.ROW_OFFSET_REGION),
frame.region(RowBasedFrameWriter.ROW_DATA_REGION), frame.region(RowBasedFrameWriter.ROW_DATA_REGION),
sortColumns.size(), sortColumns.size(),
RowKeyComparator.computeFirstFieldPosition(frameReader.signature().size()), ByteRowKeyComparator.computeFirstFieldPosition(frameReader.signature().size()),
RowKeyComparator.computeAscDescRunLengths(sortColumns) ByteRowKeyComparator.computeAscDescRunLengths(sortColumns)
); );
} }

View File

@ -32,9 +32,8 @@ public class RowKey
{ {
private static final RowKey EMPTY_KEY = new RowKey(new byte[0]); private static final RowKey EMPTY_KEY = new RowKey(new byte[0]);
// Constant to account for hashcode and object overhead // Constant to account for byte array overhead.
// 24 bytes (header) + 8 bytes (reference) + 8 bytes (hashCode long) + 4 bytes (safe estimate of hashCodeComputed) static final int OBJECT_OVERHEAD_SIZE_BYTES = 24;
static final int OBJECT_OVERHEAD_SIZE_BYTES = 44;
private final byte[] key; private final byte[] key;
@ -114,7 +113,7 @@ public class RowKey
} }
/** /**
* Estimate number of bytes taken by an object of {@link RowKey}. Only returns an estimate and does not account for * Estimate number of bytes taken by the key array. Only returns an estimate and does not account for
* platform or JVM specific implementation. * platform or JVM specific implementation.
*/ */
public int estimatedObjectSizeBytes() public int estimatedObjectSizeBytes()

View File

@ -19,129 +19,32 @@
package org.apache.druid.frame.key; package org.apache.druid.frame.key;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import org.apache.druid.frame.read.FrameReaderUtils;
import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Objects;
/** /**
* Comparator for {@link RowKey} instances. * Comparator for {@link RowKey} instances.
* *
* Comparison logic in this class is very similar to {@link FrameComparisonWidget}, but is different because it works * Delegates the comparing to a {@link ByteRowKeyComparator}.
* on byte[] instead of Frames.
*/ */
public class RowKeyComparator implements Comparator<RowKey> public class RowKeyComparator implements Comparator<RowKey>
{ {
private final int firstFieldPosition; private final ByteRowKeyComparator byteRowKeyComparatorDelegate;
private final int[] ascDescRunLengths;
private RowKeyComparator( private RowKeyComparator(final ByteRowKeyComparator byteRowKeyComparatorDelegate)
final int firstFieldPosition,
final int[] ascDescRunLengths
)
{ {
this.firstFieldPosition = firstFieldPosition; this.byteRowKeyComparatorDelegate = byteRowKeyComparatorDelegate;
this.ascDescRunLengths = ascDescRunLengths;
} }
public static RowKeyComparator create(final List<SortColumn> keyColumns) public static RowKeyComparator create(final List<SortColumn> keyColumns)
{ {
return new RowKeyComparator( return new RowKeyComparator(ByteRowKeyComparator.create(keyColumns));
computeFirstFieldPosition(keyColumns.size()),
computeAscDescRunLengths(keyColumns)
);
}
/**
* Compute the offset into each key where the first field starts.
*
* Public so {@link FrameComparisonWidgetImpl} can use it.
*/
public static int computeFirstFieldPosition(final int fieldCount)
{
return Ints.checkedCast((long) fieldCount * Integer.BYTES);
}
/**
* Given a list of sort columns, compute an array of the number of ascending fields in a run, followed by number of
* descending fields in a run, followed by ascending, etc. For example: ASC, ASC, DESC, ASC would return [2, 1, 1]
* and DESC, DESC, ASC would return [0, 2, 1].
*
* Public so {@link FrameComparisonWidgetImpl} can use it.
*/
public static int[] computeAscDescRunLengths(final List<SortColumn> sortColumns)
{
final IntList ascDescRunLengths = new IntArrayList(4);
boolean descending = false;
int runLength = 0;
for (final SortColumn column : sortColumns) {
if (column.descending() != descending) {
ascDescRunLengths.add(runLength);
runLength = 0;
descending = !descending;
}
runLength++;
}
if (runLength > 0) {
ascDescRunLengths.add(runLength);
}
return ascDescRunLengths.toIntArray();
} }
@Override @Override
@SuppressWarnings("SubtractionInCompareTo")
public int compare(final RowKey key1, final RowKey key2) public int compare(final RowKey key1, final RowKey key2)
{ {
// Similar logic to FrameComparaisonWidgetImpl, but implementation is different enough that we need our own. return byteRowKeyComparatorDelegate.compare(key1.array(), key2.array());
// Major difference is Frame v. Frame instead of byte[] v. byte[].
final byte[] keyArray1 = key1.array();
final byte[] keyArray2 = key2.array();
int comparableBytesStartPosition1 = firstFieldPosition;
int comparableBytesStartPosition2 = firstFieldPosition;
boolean ascending = true;
int field = 0;
for (int numFields : ascDescRunLengths) {
if (numFields > 0) {
final int nextField = field + numFields;
final int comparableBytesEndPosition1 = RowKeyReader.fieldEndPosition(keyArray1, nextField - 1);
final int comparableBytesEndPosition2 = RowKeyReader.fieldEndPosition(keyArray2, nextField - 1);
int cmp = FrameReaderUtils.compareByteArraysUnsigned(
keyArray1,
comparableBytesStartPosition1,
comparableBytesEndPosition1 - comparableBytesStartPosition1,
keyArray2,
comparableBytesStartPosition2,
comparableBytesEndPosition2 - comparableBytesStartPosition2
);
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
field = nextField;
comparableBytesStartPosition1 = comparableBytesEndPosition1;
comparableBytesStartPosition2 = comparableBytesEndPosition2;
}
ascending = !ascending;
}
return 0;
} }
@Override @Override
@ -154,24 +57,20 @@ public class RowKeyComparator implements Comparator<RowKey>
return false; return false;
} }
RowKeyComparator that = (RowKeyComparator) o; RowKeyComparator that = (RowKeyComparator) o;
return firstFieldPosition == that.firstFieldPosition return byteRowKeyComparatorDelegate.equals(that.byteRowKeyComparatorDelegate);
&& Arrays.equals(ascDescRunLengths, that.ascDescRunLengths);
} }
@Override @Override
public int hashCode() public int hashCode()
{ {
int result = Objects.hash(firstFieldPosition); return byteRowKeyComparatorDelegate.hashCode();
result = 31 * result + Arrays.hashCode(ascDescRunLengths);
return result;
} }
@Override @Override
public String toString() public String toString()
{ {
return "RowKeyComparator{" + return "RowKeyComparator{" +
"firstFieldPosition=" + firstFieldPosition + "byteRowKeyComparatorDelegate=" + byteRowKeyComparatorDelegate +
", ascDescRunLengths=" + Arrays.toString(ascDescRunLengths) +
'}'; '}';
} }
} }

View File

@ -0,0 +1,182 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.key;
import com.google.common.collect.ImmutableList;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.guava.Comparators;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class ByteRowKeyComparatorTest extends InitializedNullHandlingTest
{
static final RowSignature SIGNATURE =
RowSignature.builder()
.add("1", ColumnType.LONG)
.add("2", ColumnType.STRING)
.add("3", ColumnType.LONG)
.add("4", ColumnType.DOUBLE)
.build();
private static final Object[] OBJECTS1 = new Object[]{-1L, "foo", 2L, -1.2};
private static final Object[] OBJECTS2 = new Object[]{-1L, null, 2L, 1.2d};
private static final Object[] OBJECTS3 = new Object[]{-1L, "bar", 2L, 1.2d};
private static final Object[] OBJECTS4 = new Object[]{-1L, "foo", 2L, 1.2d};
private static final Object[] OBJECTS5 = new Object[]{-1L, "foo", 3L, 1.2d};
private static final Object[] OBJECTS6 = new Object[]{-1L, "foo", 2L, 1.3d};
private static final Object[] OBJECTS7 = new Object[]{1L, "foo", 2L, -1.2d};
static final List<Object[]> ALL_KEY_OBJECTS = Arrays.asList(
OBJECTS1,
OBJECTS2,
OBJECTS3,
OBJECTS4,
OBJECTS5,
OBJECTS6,
OBJECTS7
);
@Test
public void test_compare_AAAA() // AAAA = all ascending
{
final List<SortColumn> sortColumns = ImmutableList.of(
new SortColumn("1", true),
new SortColumn("2", true),
new SortColumn("3", true),
new SortColumn("4", true)
);
Assert.assertEquals(
sortUsingObjectComparator(sortColumns, ALL_KEY_OBJECTS),
sortUsingByteKeyComparator(sortColumns, ALL_KEY_OBJECTS)
);
}
@Test
public void test_compare_DDDD() // DDDD = all descending
{
final List<SortColumn> sortColumns = ImmutableList.of(
new SortColumn("1", false),
new SortColumn("2", false),
new SortColumn("3", false),
new SortColumn("4", false)
);
Assert.assertEquals(
sortUsingObjectComparator(sortColumns, ALL_KEY_OBJECTS),
sortUsingByteKeyComparator(sortColumns, ALL_KEY_OBJECTS)
);
}
@Test
public void test_compare_DAAD()
{
final List<SortColumn> sortColumns = ImmutableList.of(
new SortColumn("1", false),
new SortColumn("2", true),
new SortColumn("3", true),
new SortColumn("4", false)
);
Assert.assertEquals(
sortUsingObjectComparator(sortColumns, ALL_KEY_OBJECTS),
sortUsingByteKeyComparator(sortColumns, ALL_KEY_OBJECTS)
);
}
@Test
public void test_compare_ADDA()
{
final List<SortColumn> sortColumns = ImmutableList.of(
new SortColumn("1", true),
new SortColumn("2", false),
new SortColumn("3", false),
new SortColumn("4", true)
);
Assert.assertEquals(
sortUsingObjectComparator(sortColumns, ALL_KEY_OBJECTS),
sortUsingByteKeyComparator(sortColumns, ALL_KEY_OBJECTS)
);
}
@Test
public void test_compare_DADA()
{
final List<SortColumn> sortColumns = ImmutableList.of(
new SortColumn("1", true),
new SortColumn("2", false),
new SortColumn("3", true),
new SortColumn("4", false)
);
Assert.assertEquals(
sortUsingObjectComparator(sortColumns, ALL_KEY_OBJECTS),
sortUsingByteKeyComparator(sortColumns, ALL_KEY_OBJECTS)
);
}
@Test
public void test_equals()
{
EqualsVerifier.forClass(ByteRowKeyComparator.class)
.usingGetClass()
.verify();
}
private List<RowKey> sortUsingByteKeyComparator(final List<SortColumn> sortColumns, final List<Object[]> objectss)
{
return objectss.stream()
.map(objects -> KeyTestUtils.createKey(SIGNATURE, objects).array())
.sorted(ByteRowKeyComparator.create(sortColumns))
.map(RowKey::wrap)
.collect(Collectors.toList());
}
private List<RowKey> sortUsingObjectComparator(final List<SortColumn> sortColumns, final List<Object[]> objectss)
{
final List<Object[]> sortedObjectssCopy = objectss.stream().sorted(
(o1, o2) -> {
for (int i = 0; i < sortColumns.size(); i++) {
final SortColumn sortColumn = sortColumns.get(i);
//noinspection unchecked, rawtypes
final int cmp = Comparators.<Comparable>naturalNullsFirst()
.compare((Comparable) o1[i], (Comparable) o2[i]);
if (cmp != 0) {
return sortColumn.descending() ? -cmp : cmp;
}
}
return 0;
}
).collect(Collectors.toList());
final List<RowKey> sortedKeys = new ArrayList<>();
for (final Object[] objects : sortedObjectssCopy) {
sortedKeys.add(KeyTestUtils.createKey(SIGNATURE, objects));
}
return sortedKeys;
}
}

View File

@ -138,7 +138,10 @@ public class RowKeyComparatorTest extends InitializedNullHandlingTest
@Test @Test
public void test_equals() public void test_equals()
{ {
EqualsVerifier.forClass(RowKeyComparator.class).usingGetClass().verify(); EqualsVerifier.forClass(RowKeyComparator.class)
.withNonnullFields("byteRowKeyComparatorDelegate")
.usingGetClass()
.verify();
} }
private List<RowKey> sortUsingKeyComparator(final List<SortColumn> sortColumns, final List<Object[]> objectss) private List<RowKey> sortUsingKeyComparator(final List<SortColumn> sortColumns, final List<Object[]> objectss)