diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/KeyValue.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/KeyValue.java index a3847739fca..ce7aa372f4e 100644 --- a/hbase-common/src/main/java/org/apache/hadoop/hbase/KeyValue.java +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/KeyValue.java @@ -2759,7 +2759,12 @@ public class KeyValue implements Cell, HeapSize, Cloneable { * @throws IOException */ public static KeyValue create(int length, final DataInput in) throws IOException { - if (length == 0) return null; + + if (length <= 0) { + if (length == 0) return null; + throw new IOException("Failed read " + length + " bytes, stream corrupt?"); + } + // This is how the old Writables.readFrom used to deserialize. Didn't even vint. byte [] bytes = new byte[length]; in.readFully(bytes); diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/TestSerialization.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/TestSerialization.java index 5f54f2a09de..d7082ed370d 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/TestSerialization.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/TestSerialization.java @@ -22,11 +22,13 @@ package org.apache.hadoop.hbase; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.NavigableSet; @@ -74,7 +76,48 @@ public class TestSerialization { assertEquals(kv.getOffset(), deserializedKv.getOffset()); assertEquals(kv.getLength(), deserializedKv.getLength()); } - + + @Test public void testCreateKeyValueInvalidNegativeLength() { + + KeyValue kv_0 = new KeyValue(Bytes.toBytes("myRow"), Bytes.toBytes("myCF"), // 51 bytes + Bytes.toBytes("myQualifier"), 12345L, Bytes.toBytes("my12345")); + + KeyValue kv_1 = new KeyValue(Bytes.toBytes("myRow"), Bytes.toBytes("myCF"), // 49 bytes + Bytes.toBytes("myQualifier"), 12345L, Bytes.toBytes("my123")); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + long l = 0; + try { + l = KeyValue.oswrite(kv_0, dos, false); + l += KeyValue.oswrite(kv_1, dos, false); + assertEquals(100L, l); + } catch (IOException e) { + fail("Unexpected IOException" + e.getMessage()); + } + + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + DataInputStream dis = new DataInputStream(bais); + + try { + KeyValue.create(dis); + assertTrue(kv_0.equals(kv_1)); + } catch (Exception e) { + fail("Unexpected Exception" + e.getMessage()); + } + + // length -1 + try { + // even if we have a good kv now in dis we will just pass length with -1 for simplicity + KeyValue.create(-1, dis); + fail("Expected corrupt stream"); + } catch (Exception e) { + assertEquals("Failed read -1 bytes, stream corrupt?", e.getMessage()); + } + + } + @Test public void testSplitLogTask() throws DeserializationException { SplitLogTask slt = new SplitLogTask.Unassigned(ServerName.valueOf("mgr,1,1"));