fix front-coded bucket size handling, better validation (#13335)

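* fix front-coded bucket size handling (read the stored bucket size as an unsigned byte), better validation (the writer now rejects any bucket size that is not a power of two from 1 up to 128)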

* Update FrontCodedIndexedTest.java
This commit is contained in:
Clint Wylie 2022-11-09 13:33:01 -08:00 committed by GitHub
parent a2013e6566
commit 3e2bb4cf10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 88 additions and 4 deletions

View File

@ -77,7 +77,7 @@ public final class FrontCodedIndexed implements Indexed<ByteBuffer>
final ByteBuffer orderedBuffer = buffer.asReadOnlyBuffer().order(ordering);
final byte version = orderedBuffer.get();
Preconditions.checkArgument(version == 0, "only V0 exists, encountered " + version);
final int bucketSize = orderedBuffer.get();
final int bucketSize = Byte.toUnsignedInt(orderedBuffer.get());
final boolean hasNull = NullHandling.IS_NULL_BYTE == orderedBuffer.get();
final int numValues = VByte.readInt(orderedBuffer);
// size of offsets + values

View File

@ -22,6 +22,7 @@ package org.apache.druid.segment.data;
import com.google.common.primitives.Ints;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.io.Channels;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.io.smoosh.FileSmoosher;
@ -76,8 +77,8 @@ public class FrontCodedIndexedWriter implements DictionaryWriter<byte[]>
int bucketSize
)
{
if (Integer.bitCount(bucketSize) != 1) {
throw new ISE("bucketSize must be a power of two but was[%,d]", bucketSize);
if (Integer.bitCount(bucketSize) != 1 || bucketSize < 1 || bucketSize > 128) {
throw new IAE("bucketSize must be a power of two (from 1 up to 128) but was[%,d]", bucketSize);
}
this.segmentWriteOutMedium = segmentWriteOutMedium;
this.scratch = ByteBuffer.allocate(1 << logScratchSize).order(byteOrder);

View File

@ -21,6 +21,7 @@ package org.apache.druid.segment.data;
import com.google.common.collect.ImmutableList;
import org.apache.druid.common.utils.IdUtils;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.writeout.OnHeapMemorySegmentWriteOutMedium;
import org.apache.druid.testing.InitializedNullHandlingTest;
@ -162,7 +163,7 @@ public class FrontCodedIndexedTest extends InitializedNullHandlingTest
for (int i = 0; i < sizeBase + sizeAdjust; i++) {
values.add(IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId());
}
fillBuffer(buffer, values, 4);
fillBuffer(buffer, values, bucketSize);
FrontCodedIndexed codedUtf8Indexed = FrontCodedIndexed.read(
buffer,
@ -290,6 +291,88 @@ public class FrontCodedIndexedTest extends InitializedNullHandlingTest
Assert.assertFalse(utf8Iterator.hasNext());
}
/**
 * Writes the same sorted value set (one null entry plus {@code numValues} random ids) with every power-of-two
 * bucket size from 1 through 128, reads it back with {@link FrontCodedIndexed#read}, and checks that iteration,
 * {@code get} and {@code indexOf} all agree with the source values.
 *
 * @throws IOException if {@code fillBuffer} fails while writing the values
 */
@Test
public void testBucketSizes() throws IOException
{
  final int numValues = 10000;
  final ByteBuffer buffer = ByteBuffer.allocate(1 << 24).order(order);
  // every power of two from 1 (2^0) up to 128 (2^7)
  final int[] bucketSizes = new int[] {
      1,
      1 << 1,
      1 << 2,
      1 << 3,
      1 << 4,
      1 << 5,
      1 << 6,
      1 << 7
  };

  // sorted with the same comparator used for string dictionaries; the null entry is added explicitly
  TreeSet<String> values = new TreeSet<>(GenericIndexed.STRING_STRATEGY);
  values.add(null);
  for (int i = 0; i < numValues; i++) {
    values.add(IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId() + IdUtils.getRandomId());
  }

  for (int bucketSize : bucketSizes) {
    // the same buffer is reused for every bucket size
    fillBuffer(buffer, values, bucketSize);
    FrontCodedIndexed codedUtf8Indexed = FrontCodedIndexed.read(
        buffer,
        buffer.order()
    ).get();

    Iterator<String> newListIterator = values.iterator();
    Iterator<ByteBuffer> utf8Iterator = codedUtf8Indexed.iterator();
    int ctr = 0;
    while (utf8Iterator.hasNext() && newListIterator.hasNext()) {
      final String next = newListIterator.next();
      final ByteBuffer nextUtf8 = utf8Iterator.next();
      if (next == null) {
        Assert.assertNull(nextUtf8);
      } else {
        Assert.assertEquals(next, StringUtils.fromUtf8(nextUtf8));
        // reset the position (decoding above presumably advances it) so the buffer can be used for indexOf below
        nextUtf8.position(0);
        Assert.assertEquals(next, StringUtils.fromUtf8(codedUtf8Indexed.get(ctr)));
      }
      Assert.assertEquals(ctr, codedUtf8Indexed.indexOf(nextUtf8));
      ctr++;
    }
    // both iterators must be exhausted together
    Assert.assertEquals(newListIterator.hasNext(), utf8Iterator.hasNext());
    // expected value goes first so a failure message reports the counts the right way round
    Assert.assertEquals(numValues + 1, ctr);
  }
}
/**
 * Verifies that constructing a {@link FrontCodedIndexedWriter} with an unsupported bucket size fails with an
 * {@link IAE}.
 */
@Test
public void testBadBucketSize()
{
  final OnHeapMemorySegmentWriteOutMedium writeOutMedium = new OnHeapMemorySegmentWriteOutMedium();
  final ByteOrder nativeOrder = ByteOrder.nativeOrder();
  // zero, a value that is not a power of two, and a power of two larger than 128
  final int[] rejectedSizes = new int[]{0, 15, 256};
  for (int rejectedSize : rejectedSizes) {
    Assert.assertThrows(
        IAE.class,
        () -> new FrontCodedIndexedWriter(writeOutMedium, nativeOrder, rejectedSize)
    );
  }
}
private static long fillBuffer(ByteBuffer buffer, Iterable<String> sortedIterable, int bucketSize) throws IOException
{
Iterator<String> sortedStrings = sortedIterable.iterator();