From 3e2bb4cf10a18395ed65a9be9ad11fbb315170f1 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Wed, 9 Nov 2022 13:33:01 -0800 Subject: [PATCH] fix front-coded bucket size handling, better validation (#13335) * fix front-coded bucket size handling, better validation * Update FrontCodedIndexedTest.java --- .../druid/segment/data/FrontCodedIndexed.java | 2 +- .../segment/data/FrontCodedIndexedWriter.java | 5 +- .../segment/data/FrontCodedIndexedTest.java | 85 ++++++++++++++++++- 3 files changed, 88 insertions(+), 4 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexed.java b/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexed.java index 6e3af1ad095..d2d6c28d340 100644 --- a/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexed.java +++ b/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexed.java @@ -77,7 +77,7 @@ public final class FrontCodedIndexed implements Indexed final ByteBuffer orderedBuffer = buffer.asReadOnlyBuffer().order(ordering); final byte version = orderedBuffer.get(); Preconditions.checkArgument(version == 0, "only V0 exists, encountered " + version); - final int bucketSize = orderedBuffer.get(); + final int bucketSize = Byte.toUnsignedInt(orderedBuffer.get()); final boolean hasNull = NullHandling.IS_NULL_BYTE == orderedBuffer.get(); final int numValues = VByte.readInt(orderedBuffer); // size of offsets + values diff --git a/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexedWriter.java b/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexedWriter.java index d86fca6f6bb..b6120d6c123 100644 --- a/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexedWriter.java +++ b/processing/src/main/java/org/apache/druid/segment/data/FrontCodedIndexedWriter.java @@ -22,6 +22,7 @@ package org.apache.druid.segment.data; import com.google.common.primitives.Ints; import org.apache.druid.common.config.NullHandling; import org.apache.druid.io.Channels; +import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.io.smoosh.FileSmoosher; @@ -76,8 +77,8 @@ public class FrontCodedIndexedWriter implements DictionaryWriter int bucketSize ) { - if (Integer.bitCount(bucketSize) != 1) { - throw new ISE("bucketSize must be a power of two but was[%,d]", bucketSize); + if (Integer.bitCount(bucketSize) != 1 || bucketSize < 1 || bucketSize > 128) { + throw new IAE("bucketSize must be a power of two (from 1 up to 128) but was[%,d]", bucketSize); } this.segmentWriteOutMedium = segmentWriteOutMedium; this.scratch = ByteBuffer.allocate(1 << logScratchSize).order(byteOrder); diff --git a/processing/src/test/java/org/apache/druid/segment/data/FrontCodedIndexedTest.java b/processing/src/test/java/org/apache/druid/segment/data/FrontCodedIndexedTest.java index 7480f7b9e12..f1bd478c954 100644 --- a/processing/src/test/java/org/apache/druid/segment/data/FrontCodedIndexedTest.java +++ b/processing/src/test/java/org/apache/druid/segment/data/FrontCodedIndexedTest.java @@ -21,6 +21,7 @@ package org.apache.druid.segment.data; import com.google.common.collect.ImmutableList; import org.apache.druid.common.utils.IdUtils; +import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.segment.writeout.OnHeapMemorySegmentWriteOutMedium; import org.apache.druid.testing.InitializedNullHandlingTest; @@ -162,7 +163,7 @@ public class FrontCodedIndexedTest extends InitializedNullHandlingTest for (int i = 0; i < sizeBase + sizeAdjust; i++) { values.add(IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId()); } - fillBuffer(buffer, values, 4); + fillBuffer(buffer, values, bucketSize); FrontCodedIndexed codedUtf8Indexed = FrontCodedIndexed.read( buffer, @@ -290,6 +291,88 @@ public class FrontCodedIndexedTest extends InitializedNullHandlingTest Assert.assertFalse(utf8Iterator.hasNext()); } + @Test + public void testBucketSizes() throws IOException + { + final int numValues = 10000; + final ByteBuffer buffer = ByteBuffer.allocate(1 << 24).order(order); + final int[] bucketSizes = new int[] { + 1, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7 + }; + + TreeSet values = new TreeSet<>(GenericIndexed.STRING_STRATEGY); + values.add(null); + for (int i = 0; i < numValues; i++) { + values.add(IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId()); + } + for (int bucketSize : bucketSizes) { + fillBuffer(buffer, values, bucketSize); + FrontCodedIndexed codedUtf8Indexed = FrontCodedIndexed.read( + buffer, + buffer.order() + ).get(); + + Iterator newListIterator = values.iterator(); + Iterator utf8Iterator = codedUtf8Indexed.iterator(); + int ctr = 0; + while (utf8Iterator.hasNext() && newListIterator.hasNext()) { + final String next = newListIterator.next(); + final ByteBuffer nextUtf8 = utf8Iterator.next(); + if (next == null) { + Assert.assertNull(nextUtf8); + } else { + Assert.assertEquals(next, StringUtils.fromUtf8(nextUtf8)); + nextUtf8.position(0); + Assert.assertEquals(next, StringUtils.fromUtf8(codedUtf8Indexed.get(ctr))); + } + Assert.assertEquals(ctr, codedUtf8Indexed.indexOf(nextUtf8)); + ctr++; + } + Assert.assertEquals(newListIterator.hasNext(), utf8Iterator.hasNext()); + Assert.assertEquals(ctr, numValues + 1); + } + } + + @Test + public void testBadBucketSize() + { + OnHeapMemorySegmentWriteOutMedium medium = new OnHeapMemorySegmentWriteOutMedium(); + + Assert.assertThrows( + IAE.class, + () -> new FrontCodedIndexedWriter( + medium, + ByteOrder.nativeOrder(), + 0 + ) + ); + + Assert.assertThrows( + IAE.class, + () -> new FrontCodedIndexedWriter( + medium, + ByteOrder.nativeOrder(), + 15 + ) + ); + + Assert.assertThrows( + IAE.class, + () -> new FrontCodedIndexedWriter( + medium, + ByteOrder.nativeOrder(), + 256 + ) + ); + } + private static long fillBuffer(ByteBuffer buffer, Iterable sortedIterable, int bucketSize) throws IOException { Iterator sortedStrings = sortedIterable.iterator();