When removeNullBytes is set, length calculations did not take into account null bytes. (#17232)

* When removeNullBytes is set, the written-length calculations did not take into account the null bytes that were dropped.
This commit is contained in:
Karan Kumar 2024-10-07 18:02:52 +05:30 committed by GitHub
parent c9201ad658
commit 6a4352f466
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 23 deletions

View File

@ -128,14 +128,14 @@ public class StringFieldWriter implements FieldWriter
written++; written++;
if (len > 0) { if (len > 0) {
FrameWriterUtils.copyByteBufferToMemoryDisallowingNullBytes( int lenWritten = FrameWriterUtils.copyByteBufferToMemoryDisallowingNullBytes(
utf8Datum, utf8Datum,
memory, memory,
position + written, position + written,
len, len,
removeNullBytes removeNullBytes
); );
written += len; written += lenWritten;
} }
} }

View File

@ -212,9 +212,11 @@ public class FrameWriterUtils
/** /**
* Copies {@code src} to {@code dst}, disallowing null bytes to be written to the destination. If {@code removeNullBytes} * Copies {@code src} to {@code dst}, disallowing null bytes to be written to the destination. If {@code removeNullBytes}
* is true, the method will drop the null bytes, and if it is false, the method will throw an exception. * is true, the method will drop the null bytes, and if it is false, the method will throw an exception. The written bytes
* can be less than "len" if the null bytes are dropped, and the callers must evaluate the return value to see the actual
* length of the buffer that is copied
*/ */
public static void copyByteBufferToMemoryDisallowingNullBytes( public static int copyByteBufferToMemoryDisallowingNullBytes(
final ByteBuffer src, final ByteBuffer src,
final WritableMemory dst, final WritableMemory dst,
final long dstPosition, final long dstPosition,
@ -222,11 +224,16 @@ public class FrameWriterUtils
final boolean removeNullBytes final boolean removeNullBytes
) )
{ {
copyByteBufferToMemory(src, dst, dstPosition, len, false, removeNullBytes); return copyByteBufferToMemory(src, dst, dstPosition, len, false, removeNullBytes);
} }
/** /**
* Copies "len" bytes from {@code src.position()} to "dstPosition" in "memory". Does not update the position of src. * Tries to copy "len" bytes from {@code src.position()} to "dstPosition" in "memory". If removeNullBytes is set to true,
* it will remove the U+0000 bytes from the src buffer, and the written bytes will be less than "len". It is imperative that the
* callers check the number of written bytes when "removeNullBytes" can be set to true, i.e. this method is invoked via
* {@link #copyByteBufferToMemoryDisallowingNullBytes}
* <p>
* Does not update the position of src.
* <p> * <p>
* Whenever "allowNullBytes" is true, "removeNullBytes" must be false. Use the methods {@link #copyByteBufferToMemoryAllowingNullBytes} * Whenever "allowNullBytes" is true, "removeNullBytes" must be false. Use the methods {@link #copyByteBufferToMemoryAllowingNullBytes}
* and {@link #copyByteBufferToMemoryDisallowingNullBytes} to copy between the memory * and {@link #copyByteBufferToMemoryDisallowingNullBytes} to copy between the memory
@ -234,7 +241,7 @@ public class FrameWriterUtils
* *
* @throws InvalidNullByteException if "allowNullBytes" and "removeNullBytes" is false and a null byte is encountered * @throws InvalidNullByteException if "allowNullBytes" and "removeNullBytes" is false and a null byte is encountered
*/ */
private static void copyByteBufferToMemory( private static int copyByteBufferToMemory(
final ByteBuffer src, final ByteBuffer src,
final WritableMemory dst, final WritableMemory dst,
final long dstPosition, final long dstPosition,
@ -251,6 +258,7 @@ public class FrameWriterUtils
} }
final int srcEnd = src.position() + len; final int srcEnd = src.position() + len;
int writtenLength = 0;
if (allowNullBytes) { if (allowNullBytes) {
if (src.hasArray()) { if (src.hasArray()) {
@ -264,6 +272,8 @@ public class FrameWriterUtils
dst.putByte(q, b); dst.putByte(q, b);
} }
} }
// The method does not alter the length of the memory copied if null bytes are allowed
writtenLength = len;
} else { } else {
long q = dstPosition; long q = dstPosition;
for (int p = src.position(); p < srcEnd; p++) { for (int p = src.position(); p < srcEnd; p++) {
@ -282,9 +292,11 @@ public class FrameWriterUtils
} else { } else {
dst.putByte(q, b); dst.putByte(q, b);
q++; q++;
writtenLength++;
} }
} }
} }
return writtenLength;
} }
/** /**

View File

@ -21,6 +21,7 @@ package org.apache.druid.frame.field;
import org.apache.datasketches.memory.WritableMemory; import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.write.InvalidNullByteException;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.ColumnValueSelector;
@ -40,9 +41,11 @@ import org.mockito.junit.MockitoRule;
import org.mockito.quality.Strictness; import org.mockito.quality.Strictness;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class StringFieldWriterTest extends InitializedNullHandlingTest public class StringFieldWriterTest extends InitializedNullHandlingTest
{ {
@ -57,9 +60,12 @@ public class StringFieldWriterTest extends InitializedNullHandlingTest
@Mock @Mock
public DimensionSelector selectorUtf8; public DimensionSelector selectorUtf8;
private WritableMemory memory; private WritableMemory memory;
private FieldWriter fieldWriter; private FieldWriter fieldWriter;
private FieldWriter fieldWriterUtf8; private FieldWriter fieldWriterUtf8;
private FieldWriter fieldWriterRemoveNull;
private FieldWriter fieldWriterUtf8RemoveNull;
@Before @Before
public void setUp() public void setUp()
@ -67,13 +73,32 @@ public class StringFieldWriterTest extends InitializedNullHandlingTest
memory = WritableMemory.allocate(1000); memory = WritableMemory.allocate(1000);
fieldWriter = new StringFieldWriter(selector, false); fieldWriter = new StringFieldWriter(selector, false);
fieldWriterUtf8 = new StringFieldWriter(selectorUtf8, false); fieldWriterUtf8 = new StringFieldWriter(selectorUtf8, false);
fieldWriterRemoveNull = new StringFieldWriter(selector, true);
fieldWriterUtf8RemoveNull = new StringFieldWriter(selectorUtf8, true);
} }
@After @After
public void tearDown() public void tearDown()
{ {
fieldWriter.close(); for (FieldWriter fw : getFieldWriter(FieldWritersType.ALL)) {
fieldWriterUtf8.close(); try {
fw.close();
}
catch (Exception ignore) {
}
}
}
private List<FieldWriter> getFieldWriter(FieldWritersType fieldWritersType)
{
if (fieldWritersType == FieldWritersType.NULL_REPLACING) {
return Arrays.asList(fieldWriterRemoveNull, fieldWriterUtf8RemoveNull);
} else if (fieldWritersType == FieldWritersType.ALL) {
return Arrays.asList(fieldWriter, fieldWriterUtf8, fieldWriterRemoveNull, fieldWriterUtf8RemoveNull);
} else {
throw new ISE("Handler missing for type:[%s]", fieldWritersType);
}
} }
@Test @Test
@ -100,31 +125,63 @@ public class StringFieldWriterTest extends InitializedNullHandlingTest
doTest(Arrays.asList("foo", "bar")); doTest(Arrays.asList("foo", "bar"));
} }
@Test @Test
public void testMultiValueStringContainingNulls() public void testMultiValueStringContainingNulls()
{ {
doTest(Arrays.asList("foo", NullHandling.emptyToNullIfNeeded(""), "bar", null)); doTest(Arrays.asList("foo", NullHandling.emptyToNullIfNeeded(""), "bar", null));
} }
/**
 * Runs the round-trip check against the {@code NULL_REPLACING} writers only. {@code doTest} compares the values
 * read back from memory with the inputs after every U+0000 has been stripped from them.
 */
@Test
public void testNullByteReplacement()
{
  // NOTE(review): only "abc\u0000" visibly contains a U+0000. The middle value concatenates whatever
  // NullHandling.emptyToNullIfNeeded("") returns between "foo" and "bar" - confirm that is the intended coverage.
  final List<String> valuesWithNullByte = Arrays.asList(
      "abc\u0000",
      "foo" + NullHandling.emptyToNullIfNeeded("") + "bar",
      "def"
  );

  doTest(valuesWithNullByte, FieldWritersType.NULL_REPLACING);
}
/**
 * Mocks the selectors with input that contains a U+0000 and expects a write-then-read through
 * {@code fieldWriter} and through {@code fieldWriterUtf8} (the two writers outside the {@code NULL_REPLACING}
 * group) to throw {@link InvalidNullByteException}.
 */
@Test
public void testNullByteNotReplaced()
{
  final List<String> valuesWithNullByte = Arrays.asList(
      "abc\u0000",
      "foo" + NullHandling.emptyToNullIfNeeded("") + "bar",
      "def"
  );
  mockSelectors(valuesWithNullByte);

  // Each writer is checked separately so a failure identifies which one accepted the null byte.
  Assert.assertThrows(InvalidNullByteException.class, () -> doTestWithSpecificFieldWriter(fieldWriter));
  Assert.assertThrows(InvalidNullByteException.class, () -> doTestWithSpecificFieldWriter(fieldWriterUtf8));
}
private void doTest(final List<String> values) private void doTest(final List<String> values)
{
doTest(values, FieldWritersType.ALL);
}
private void doTest(final List<String> values, FieldWritersType fieldWritersType)
{ {
mockSelectors(values); mockSelectors(values);
// Non-UTF8 test List<FieldWriter> fieldWriters = getFieldWriter(fieldWritersType);
{ for (FieldWriter fw : fieldWriters) {
final long written = writeToMemory(fieldWriter); final Object[] valuesRead = doTestWithSpecificFieldWriter(fw);
final Object[] valuesRead = readFromMemory(written); List<String> expectedResults = new ArrayList<>(values);
Assert.assertEquals("values read (non-UTF8)", values, Arrays.asList(valuesRead)); if (fieldWritersType == FieldWritersType.NULL_REPLACING) {
expectedResults = expectedResults.stream()
.map(val -> StringUtils.replace(val, "\u0000", ""))
.collect(Collectors.toList());
}
Assert.assertEquals("values read", expectedResults, Arrays.asList(valuesRead));
}
} }
// UTF8 test private Object[] doTestWithSpecificFieldWriter(FieldWriter fieldWriter)
{ {
final long writtenUtf8 = writeToMemory(fieldWriterUtf8); final long written = writeToMemory(fieldWriter);
final Object[] valuesReadUtf8 = readFromMemory(writtenUtf8); return readFromMemory(written);
Assert.assertEquals("values read (UTF8)", values, Arrays.asList(valuesReadUtf8));
}
} }
private void mockSelectors(final List<String> values) private void mockSelectors(final List<String> values)
{ {
final RangeIndexedInts row = new RangeIndexedInts(); final RangeIndexedInts row = new RangeIndexedInts();
@ -183,9 +240,20 @@ public class StringFieldWriterTest extends InitializedNullHandlingTest
memory.getByteArray(MEMORY_POSITION, bytes, 0, (int) written); memory.getByteArray(MEMORY_POSITION, bytes, 0, (int) written);
final FieldReader fieldReader = FieldReaders.create("columnNameDoesntMatterHere", ColumnType.STRING_ARRAY); final FieldReader fieldReader = FieldReaders.create("columnNameDoesntMatterHere", ColumnType.STRING_ARRAY);
final ColumnValueSelector<?> selector = final ColumnValueSelector<?> selector = fieldReader.makeColumnValueSelector(
fieldReader.makeColumnValueSelector(memory, new ConstantFieldPointer(MEMORY_POSITION, -1)); memory,
new ConstantFieldPointer(
MEMORY_POSITION,
-1
)
);
return (Object[]) selector.getObject(); return (Object[]) selector.getObject();
} }
/**
 * Selects which group of this test's field writers {@code getFieldWriter} returns.
 */
private enum FieldWritersType
{
NULL_REPLACING, // only fieldWriterRemoveNull and fieldWriterUtf8RemoveNull
ALL // all four writers created in setUp
}
} }