From b78d35598c9e96d9237c330e1feede11c390aa9e Mon Sep 17 00:00:00 2001 From: chenglei Date: Tue, 7 Sep 2021 22:25:04 +0800 Subject: [PATCH] HBASE-26197 Fix some obvious bugs in MultiByteBuff.put (#3586) Signed-off-by: stack Signed-off-by: Duo Zhang --- .../hadoop/hbase/nio/MultiByteBuff.java | 119 ++++++++++++++---- .../hadoop/hbase/nio/TestMultiByteBuff.java | 112 +++++++++++++++++ 2 files changed, 208 insertions(+), 23 deletions(-) diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/nio/MultiByteBuff.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/nio/MultiByteBuff.java index df0ae8eaadc..a25791e5ce8 100644 --- a/hbase-common/src/main/java/org/apache/hadoop/hbase/nio/MultiByteBuff.java +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/nio/MultiByteBuff.java @@ -736,52 +736,125 @@ public class MultiByteBuff extends ByteBuff { } /** - * Copies from a src MBB to this MBB. - * @param offset the position in this MBB to which the copy should happen + * Copies from a src BB to this MBB. This will be absolute positional copying and won't affect the + * position of any of the buffers. + * @param destOffset the position in this MBB to which the copy should happen * @param src the src MBB * @param srcOffset the offset in the src MBB from where the elements should be read * @param length the length upto which the copy should happen + * @throws BufferUnderflowException If there are fewer than length bytes remaining in src + * ByteBuff. + * @throws BufferOverflowException If there is insufficient available space in this MBB for length + * bytes. */ @Override - public MultiByteBuff put(int offset, ByteBuff src, int srcOffset, int length) { + public MultiByteBuff put(int destOffset, ByteBuff src, int srcOffset, int length) { checkRefCount(); - int destItemIndex = getItemIndex(offset); - int srcItemIndex = getItemIndex(srcOffset); + int destItemIndex = getItemIndex(destOffset); + int srcItemIndex = getItemIndexForByteBuff(src, srcOffset, length); + ByteBuffer destItem = this.items[destItemIndex]; - offset = offset - this.itemBeginPos[destItemIndex]; + destOffset = this.getRelativeOffset(destOffset, destItemIndex); ByteBuffer srcItem = getItemByteBuffer(src, srcItemIndex); - srcOffset = srcOffset - this.itemBeginPos[srcItemIndex]; - int toRead, toWrite, toMove; + srcOffset = getRelativeOffsetForByteBuff(src, srcOffset, srcItemIndex); + while (length > 0) { - toWrite = destItem.limit() - offset; - toRead = srcItem.limit() - srcOffset; - toMove = Math.min(length, Math.min(toRead, toWrite)); - ByteBufferUtils.copyFromBufferToBuffer(srcItem, destItem, srcOffset, offset, toMove); + int toWrite = destItem.limit() - destOffset; + if (toWrite <= 0) { + throw new BufferOverflowException(); + } + int toRead = srcItem.limit() - srcOffset; + if (toRead <= 0) { + throw new BufferUnderflowException(); + } + int toMove = Math.min(length, Math.min(toRead, toWrite)); + ByteBufferUtils.copyFromBufferToBuffer(srcItem, destItem, srcOffset, destOffset, toMove); length -= toMove; - if (length == 0) break; + if (length == 0) { + break; + } if (toRead < toWrite) { - srcItem = getItemByteBuffer(src, ++srcItemIndex); + if (++srcItemIndex >= getItemByteBufferCount(src)) { + throw new BufferUnderflowException(); + } + srcItem = getItemByteBuffer(src, srcItemIndex); srcOffset = 0; - offset += toMove; + destOffset += toMove; } else if (toRead > toWrite) { - destItem = this.items[++destItemIndex]; - offset = 0; + if (++destItemIndex >= this.items.length) { + throw new BufferOverflowException(); + } + destItem = this.items[destItemIndex]; + destOffset = 0; srcOffset += toMove; } else { // toRead = toWrite case - srcItem = getItemByteBuffer(src, ++srcItemIndex); + if (++srcItemIndex >= getItemByteBufferCount(src)) { + throw new BufferUnderflowException(); + } + srcItem = getItemByteBuffer(src, srcItemIndex); srcOffset = 0; - destItem = this.items[++destItemIndex]; - offset = 0; + if (++destItemIndex >= this.items.length) { + throw new BufferOverflowException(); + } + destItem = this.items[destItemIndex]; + destOffset = 0; } } return this; } - private static ByteBuffer getItemByteBuffer(ByteBuff buf, int index) { - return (buf instanceof SingleByteBuff) ? buf.nioByteBuffers()[0] - : ((MultiByteBuff) buf).items[index]; + private static ByteBuffer getItemByteBuffer(ByteBuff buf, int byteBufferIndex) { + if (buf instanceof SingleByteBuff) { + if (byteBufferIndex != 0) { + throw new IndexOutOfBoundsException( + "index:[" + byteBufferIndex + "],but only index 0 is valid."); + } + return buf.nioByteBuffers()[0]; + } + MultiByteBuff multiByteBuff = (MultiByteBuff) buf; + if (byteBufferIndex < 0 || byteBufferIndex >= multiByteBuff.items.length) { + throw new IndexOutOfBoundsException( + "index:[" + byteBufferIndex + "],but only index [0-" + multiByteBuff.items.length + + ") is valid."); + } + return multiByteBuff.items[byteBufferIndex]; + } + + private static int getItemIndexForByteBuff(ByteBuff byteBuff, int offset, int length) { + if (byteBuff instanceof SingleByteBuff) { + ByteBuffer byteBuffer = byteBuff.nioByteBuffers()[0]; + if (offset + length > byteBuffer.limit()) { + throw new BufferUnderflowException(); + } + return 0; + } + MultiByteBuff multiByteBuff = (MultiByteBuff) byteBuff; + return multiByteBuff.getItemIndex(offset); + } + + private static int getRelativeOffsetForByteBuff(ByteBuff byteBuff, int globalOffset, + int itemIndex) { + if (byteBuff instanceof SingleByteBuff) { + if (itemIndex != 0) { + throw new IndexOutOfBoundsException("index:[" + itemIndex + "],but only index 0 is valid."); + } + return globalOffset; + } + return ((MultiByteBuff) byteBuff).getRelativeOffset(globalOffset, itemIndex); + } + + private int getRelativeOffset(int globalOffset, int itemIndex) { + if (itemIndex < 0 || itemIndex >= this.items.length) { + throw new IndexOutOfBoundsException( + "index:[" + itemIndex + "],but only index [0-" + this.items.length + ") is valid."); + } + return globalOffset - this.itemBeginPos[itemIndex]; + } + + private static int getItemByteBufferCount(ByteBuff buf) { + return (buf instanceof SingleByteBuff) ? 1 : ((MultiByteBuff) buf).items.length; } /** diff --git a/hbase-common/src/test/java/org/apache/hadoop/hbase/nio/TestMultiByteBuff.java b/hbase-common/src/test/java/org/apache/hadoop/hbase/nio/TestMultiByteBuff.java index 74d09405b4a..563f82a1cd9 100644 --- a/hbase-common/src/test/java/org/apache/hadoop/hbase/nio/TestMultiByteBuff.java +++ b/hbase-common/src/test/java/org/apache/hadoop/hbase/nio/TestMultiByteBuff.java @@ -25,6 +25,7 @@ import static org.junit.Assert.fail; import java.io.IOException; import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import org.apache.hadoop.hbase.HBaseClassTestRule; import org.apache.hadoop.hbase.testclassification.MiscTests; @@ -482,4 +483,115 @@ public class TestMultiByteBuff { assertEquals(out.position(), 12); assertTrue(Bytes.equals(Bytes.toBytes("abcdekabcdef"), 0, 12, out.array(), 0, 12)); } + + @Test + public void testPositionalPutByteBuff() throws Exception { + ByteBuffer bb1 = ByteBuffer.allocate(100); + ByteBuffer bb2 = ByteBuffer.allocate(100); + MultiByteBuff srcMultiByteBuff = new MultiByteBuff(bb1, bb2); + for (int i = 0; i < 25; i++) { + srcMultiByteBuff.putLong(i * 8L); + } + // Test MultiByteBuff To MultiByteBuff + doTestPositionalPutByteBuff(srcMultiByteBuff); + + ByteBuffer bb3 = ByteBuffer.allocate(200); + SingleByteBuff srcSingleByteBuff = new SingleByteBuff(bb3); + for (int i = 0; i < 25; i++) { + srcSingleByteBuff.putLong(i * 8L); + } + // Test SingleByteBuff To MultiByteBuff + doTestPositionalPutByteBuff(srcSingleByteBuff); + } + + private void doTestPositionalPutByteBuff(ByteBuff srcByteBuff) throws Exception { + ByteBuffer bb3 = ByteBuffer.allocate(50); + ByteBuffer bb4 = ByteBuffer.allocate(50); + ByteBuffer bb5 = ByteBuffer.allocate(50); + ByteBuffer bb6 = ByteBuffer.allocate(50); + MultiByteBuff destMultiByteBuff = new MultiByteBuff(bb3, bb4, bb5, bb6); + + // full copy + destMultiByteBuff.put(0, srcByteBuff, 0, 200); + int compareTo = ByteBuff.compareTo(srcByteBuff, 0, 200, destMultiByteBuff, 0, 200); + assertTrue(compareTo == 0); + + // Test src to dest first ByteBuffer + destMultiByteBuff.put(0, srcByteBuff, 32, 63); + compareTo = ByteBuff.compareTo(srcByteBuff, 32, 63, destMultiByteBuff, 0, 63); + assertTrue(compareTo == 0); + + // Test src to dest first and second ByteBuffer + destMultiByteBuff.put(0, srcByteBuff, 0, 63); + compareTo = ByteBuff.compareTo(srcByteBuff, 0, 63, destMultiByteBuff, 0, 63); + assertTrue(compareTo == 0); + + // Test src to dest third ByteBuffer + destMultiByteBuff.put(100, srcByteBuff, 100, 50); + compareTo = ByteBuff.compareTo(srcByteBuff, 100, 50, destMultiByteBuff, 100, 50); + assertTrue(compareTo == 0); + + // Test src to dest first,second and third ByteBuffer + destMultiByteBuff.put(48, srcByteBuff, 32, 63); + compareTo = ByteBuff.compareTo(srcByteBuff, 32, 63, destMultiByteBuff, 48, 63); + assertTrue(compareTo == 0); + + // Test src to dest first,second,third and fourth ByteBuffer + destMultiByteBuff.put(48, srcByteBuff, 32, 120); + compareTo = ByteBuff.compareTo(srcByteBuff, 32, 120, destMultiByteBuff, 48, 120); + assertTrue(compareTo == 0); + + // Test src to dest first and second ByteBuffer + destMultiByteBuff.put(0, srcByteBuff, 132, 63); + compareTo = ByteBuff.compareTo(srcByteBuff, 132, 63, destMultiByteBuff, 0, 63); + assertTrue(compareTo == 0); + + // Test src to dest second,third and fourth ByteBuffer + destMultiByteBuff.put(95, srcByteBuff, 132, 67); + compareTo = ByteBuff.compareTo(srcByteBuff, 132, 67, destMultiByteBuff, 95, 67); + assertTrue(compareTo == 0); + + // Test src to dest fourth ByteBuffer + destMultiByteBuff.put(162, srcByteBuff, 132, 24); + compareTo = ByteBuff.compareTo(srcByteBuff, 132, 24, destMultiByteBuff, 162, 24); + assertTrue(compareTo == 0); + + // Test src BufferUnderflowException + try { + destMultiByteBuff.put(0, srcByteBuff, 0, 300); + fail(); + } catch (BufferUnderflowException e) { + assertTrue(e != null); + } + + try { + destMultiByteBuff.put(95, srcByteBuff, 132, 89); + fail(); + } catch (BufferUnderflowException e) { + assertTrue(e != null); + } + + // Test dest BufferOverflowException + try { + destMultiByteBuff.put(100, srcByteBuff, 0, 101); + fail(); + } catch (BufferOverflowException e) { + assertTrue(e != null); + } + + try { + destMultiByteBuff.put(151, srcByteBuff, 132, 68); + fail(); + } catch (BufferOverflowException e) { + assertTrue(e != null); + } + + destMultiByteBuff = new MultiByteBuff(bb3, bb4); + try { + destMultiByteBuff.put(0, srcByteBuff, 0, 101); + fail(); + } catch (BufferOverflowException e) { + assertTrue(e != null); + } + } }