diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java index c8ac48ab186..1063bbdfec1 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java @@ -28,6 +28,7 @@ import org.apache.druid.segment.GenericColumnSerializer; import org.apache.druid.segment.column.ColumnBuilder; import org.apache.druid.segment.data.GenericIndexed; import org.apache.druid.segment.data.ObjectStrategy; +import org.apache.druid.segment.data.SafeWritableMemory; import org.apache.druid.segment.serde.ComplexColumnPartSupplier; import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.apache.druid.segment.serde.ComplexMetricSerde; @@ -70,7 +71,7 @@ public class HllSketchMergeComplexMetricSerde extends ComplexMetricSerde if (object == null) { return null; } - return deserializeSketch(object); + return deserializeSketchSafe(object); } }; } @@ -98,6 +99,18 @@ public class HllSketchMergeComplexMetricSerde extends ComplexMetricSerde throw new IAE("Object is not of a type that can be deserialized to an HllSketch:" + object.getClass().getName()); } + static HllSketch deserializeSketchSafe(final Object object) + { + if (object instanceof String) { + return HllSketch.wrap(SafeWritableMemory.wrap(StringUtils.decodeBase64(((String) object).getBytes(StandardCharsets.UTF_8)))); + } else if (object instanceof byte[]) { + return HllSketch.wrap(SafeWritableMemory.wrap((byte[]) object)); + } else if (object instanceof HllSketch) { + return (HllSketch) object; + } + throw new IAE("Object is not of a type that can be deserialized to an HllSketch:" + object.getClass().getName()); + } + // support large columns @Override public GenericColumnSerializer getSerializer(final SegmentWriteOutMedium segmentWriteOutMedium, final String column) diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java index 34145863fdf..65257b22b79 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java @@ -22,7 +22,9 @@ package org.apache.druid.query.aggregation.datasketches.hll; import org.apache.datasketches.hll.HllSketch; import org.apache.datasketches.memory.Memory; import org.apache.druid.segment.data.ObjectStrategy; +import org.apache.druid.segment.data.SafeWritableMemory; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -55,4 +57,12 @@ public class HllSketchObjectStrategy implements ObjectStrategy return sketch.toCompactByteArray(); } + @Nullable + @Override + public HllSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes) + { + return HllSketch.wrap( + SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes) + ); + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java index 4c18a978560..e5249853ac3 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java @@ -91,7 +91,7 @@ public class KllDoublesSketchComplexMetricSerde extends ComplexMetricSerde if (object == null || object instanceof KllDoublesSketch || object instanceof Memory) { return object; } - return KllDoublesSketchOperations.deserialize(object); + return KllDoublesSketchOperations.deserializeSafe(object); } }; } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java index 97e670a625a..17cb94e2fcf 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java @@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays; import org.apache.datasketches.kll.KllDoublesSketch; import org.apache.datasketches.memory.Memory; import org.apache.druid.segment.data.ObjectStrategy; +import org.apache.druid.segment.data.SafeWritableMemory; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -60,4 +62,15 @@ public class KllDoublesSketchObjectStrategy implements ObjectStrategy return ByteArrays.EMPTY_ARRAY; } } + + @Nullable + @Override + public SketchHolder fromByteBufferSafe(ByteBuffer buffer, int numBytes) + { + if (numBytes == 0) { + return SketchHolder.EMPTY; + } + + return SketchHolder.of( + SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes) + ); + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java index a824312c0ef..4f3ecfae291 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java @@ -59,7 +59,7 @@ public class SketchMergeComplexMetricSerde extends ComplexMetricSerde public SketchHolder extractValue(InputRow inputRow, String metricName) { final Object object = inputRow.getRaw(metricName); - return object == null ? null : SketchHolder.deserialize(object); + return object == null ? null : SketchHolder.deserializeSafe(object); } }; } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java index 19c8da292b4..028bcdc3549 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java @@ -60,7 +60,7 @@ public class ArrayOfDoublesSketchMergeComplexMetricSerde extends ComplexMetricSe if (object == null || object instanceof ArrayOfDoublesSketch) { return object; } - return ArrayOfDoublesSketchOperations.deserialize(object); + return ArrayOfDoublesSketchOperations.deserializeSafe(object); } }; } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java index 1ae950e068f..f893c83b570 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java @@ -23,6 +23,7 @@ import org.apache.datasketches.memory.Memory; import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketch; import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketches; import org.apache.druid.segment.data.ObjectStrategy; +import org.apache.druid.segment.data.SafeWritableMemory; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -48,7 +49,9 @@ public class ArrayOfDoublesSketchObjectStrategy implements ObjectStrategy objectStrategy.fromByteBufferSafe(buf2, garbage2.length).copy() + ); + } + + // non sketch that is too short to contain header should fail with regular java buffer exception + final byte[] garbage = new byte[]{0x01, 0x02}; + final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).copy() + ); + + // non sketch that is long enough to check (this one doesn't actually need 'safe' read) + final byte[] garbageLonger = StringUtils.toUtf8("notasketch"); + final ByteBuffer buf4 = ByteBuffer.wrap(garbageLonger).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + SketchesArgumentException.class, + () -> objectStrategy.fromByteBufferSafe(buf4, garbageLonger.length).copy() + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java index 3628c5e6212..0ae46bef496 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java @@ -23,10 +23,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.datasketches.kll.KllDoublesSketch; import org.apache.druid.data.input.MapBasedInputRow; +import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.junit.Assert; import org.junit.Test; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + public class KllDoublesSketchComplexMetricSerdeTest { @Test @@ -92,4 +96,44 @@ public class KllDoublesSketchComplexMetricSerdeTest Assert.assertEquals(1, sketch.getNumRetained()); Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d); } + + @Test + public void testSafeRead() + { + final KllDoublesSketchComplexMetricSerde serde = new KllDoublesSketchComplexMetricSerde(); + final ObjectStrategy objectStrategy = serde.getObjectStrategy(); + + KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance(); + sketch.update(1.1); + sketch.update(1.2); + final byte[] bytes = sketch.toByteArray(); + + ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + + // valid sketch should not explode when converted to byte array, which reads the memory + objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray(); + + // corrupted sketch should fail with a regular java buffer exception, not all subsets actually fail with the same + // index out of bounds exceptions, but at least this many do + for (int subset = 3; subset < 24; subset++) { + final byte[] garbage2 = new byte[subset]; + for (int i = 0; i < garbage2.length; i++) { + garbage2[i] = buf.get(i); + } + + final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).toByteArray() + ); + } + + // non sketch that is too short to contain header should fail with regular java buffer exception + final byte[] garbage = new byte[]{0x01, 0x02}; + final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).toByteArray() + ); + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperationsTest.java new file mode 100644 index 00000000000..d2b0e383984 --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperationsTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.kll; + +import org.apache.datasketches.kll.KllDoublesSketch; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class KllDoublesSketchOperationsTest +{ + @Test + public void testDeserializeSafe() + { + KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance(); + sketch.update(1.1); + sketch.update(1.2); + final byte[] bytes = sketch.toByteArray(); + final String base64 = StringUtils.encodeBase64String(bytes); + + Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(sketch).toByteArray()); + Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(bytes).toByteArray()); + Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(base64).toByteArray()); + + final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 20); + Assert.assertThrows(IndexOutOfBoundsException.class, () -> KllDoublesSketchOperations.deserializeSafe(trunacted)); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> KllDoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(trunacted)) + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java index 5ff441df1c1..c6b8c310221 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java @@ -23,10 +23,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.datasketches.kll.KllFloatsSketch; import org.apache.druid.data.input.MapBasedInputRow; +import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.junit.Assert; import org.junit.Test; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + public class KllFloatsSketchComplexMetricSerdeTest { @Test @@ -92,4 +96,44 @@ public class KllFloatsSketchComplexMetricSerdeTest Assert.assertEquals(1, sketch.getNumRetained()); Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d); } + + @Test + public void testSafeRead() + { + final KllFloatsSketchComplexMetricSerde serde = new KllFloatsSketchComplexMetricSerde(); + final ObjectStrategy objectStrategy = serde.getObjectStrategy(); + + KllFloatsSketch sketch = KllFloatsSketch.newHeapInstance(); + sketch.update(1.1f); + sketch.update(1.2f); + final byte[] bytes = sketch.toByteArray(); + + ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + + // valid sketch should not explode when converted to byte array, which reads the memory + objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray(); + + // corrupted sketch should fail with a regular java buffer exception, not all subsets actually fail with the same + // index out of bounds exceptions, but at least this many do + for (int subset = 3; subset < 24; subset++) { + final byte[] garbage2 = new byte[subset]; + for (int i = 0; i < garbage2.length; i++) { + garbage2[i] = buf.get(i); + } + + final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).toByteArray() + ); + } + + // non sketch that is too short to contain header should fail with regular java buffer exception + final byte[] garbage = new byte[]{0x01, 0x02}; + final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).toByteArray() + ); + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperationsTest.java new file mode 100644 index 00000000000..613b38c6601 --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperationsTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.kll; + +import org.apache.datasketches.kll.KllFloatsSketch; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class KllFloatsSketchOperationsTest +{ + @Test + public void testDeserializeSafe() + { + KllFloatsSketch sketch = KllFloatsSketch.newHeapInstance(); + sketch.update(1.1f); + sketch.update(1.2f); + final byte[] bytes = sketch.toByteArray(); + final String base64 = StringUtils.encodeBase64String(bytes); + + Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(sketch).toByteArray()); + Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(bytes).toByteArray()); + Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(base64).toByteArray()); + + final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 20); + Assert.assertThrows(IndexOutOfBoundsException.class, () -> KllFloatsSketchOperations.deserializeSafe(trunacted)); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> KllFloatsSketchOperations.deserializeSafe(StringUtils.encodeBase64String(trunacted)) + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java index e198c770425..7dc82baee92 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java @@ -22,11 +22,16 @@ package org.apache.druid.query.aggregation.datasketches.quantiles; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.datasketches.quantiles.DoublesSketch; +import org.apache.datasketches.quantiles.DoublesUnion; import org.apache.druid.data.input.MapBasedInputRow; +import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.junit.Assert; import org.junit.Test; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + public class DoublesSketchComplexMetricSerdeTest { @Test @@ -92,4 +97,42 @@ public class DoublesSketchComplexMetricSerdeTest Assert.assertEquals(1, sketch.getRetainedItems()); Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d); } + + @Test + public void testSafeRead() + { + final DoublesSketchComplexMetricSerde serde = new DoublesSketchComplexMetricSerde(); + DoublesUnion union = DoublesUnion.builder().setMaxK(1024).build(); + union.update(1.1); + final byte[] bytes = union.toByteArray(); + + ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + ObjectStrategy objectStrategy = serde.getObjectStrategy(); + + // valid sketch should not explode when copied, which reads the memory + objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray(true); + + // corrupted sketch should fail with a regular java buffer exception + for (int subset = 3; subset < 15; subset++) { + final byte[] garbage2 = new byte[subset]; + for (int i = 0; i < garbage2.length; i++) { + garbage2[i] = buf.get(i); + } + + final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + "i " + subset, + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).toByteArray(true) + ); + } + + // non sketch that is too short to contain header should fail with regular java buffer exception + final byte[] garbage = new byte[]{0x01, 0x02}; + final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).toByteArray(true) + ); + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperationsTest.java new file mode 100644 index 00000000000..38e5d39a91b --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperationsTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.quantiles; + +import org.apache.datasketches.quantiles.DoublesUnion; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class DoublesSketchOperationsTest +{ + @Test + public void testDeserializeSafe() + { + DoublesUnion union = DoublesUnion.builder().setMaxK(1024).build(); + union.update(1.1); + final byte[] bytes = union.getResult().toByteArray(); + final String base64 = StringUtils.encodeBase64String(bytes); + + Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(union.getResult()).toByteArray()); + Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(bytes).toByteArray()); + Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(base64).toByteArray()); + + final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 4); + Assert.assertThrows(IndexOutOfBoundsException.class, () -> DoublesSketchOperations.deserializeSafe(trunacted)); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> DoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64(trunacted)) + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategyTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategyTest.java new file mode 100644 index 00000000000..5619facd5f6 --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategyTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.theta; + +import org.apache.datasketches.Family; +import org.apache.datasketches.SketchesArgumentException; +import org.apache.datasketches.theta.SetOperation; +import org.apache.datasketches.theta.Union; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class SketchHolderObjectStrategyTest +{ + @Test + public void testSafeRead() + { + SketchHolderObjectStrategy objectStrategy = new SketchHolderObjectStrategy(); + Union union = (Union) SetOperation.builder().setNominalEntries(1024).build(Family.UNION); + union.update(1234L); + + final byte[] bytes = union.getResult().toByteArray(); + + ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + + // valid sketch should not explode when copied, which reads the memory + objectStrategy.fromByteBufferSafe(buf, bytes.length).getSketch().compact().getCompactBytes(); + + // corrupted sketch should fail with a regular java buffer exception + for (int subset = 3; subset < bytes.length - 1; subset++) { + final byte[] garbage2 = new byte[subset]; + for (int i = 0; i < garbage2.length; i++) { + garbage2[i] = buf.get(i); + } + + final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).getSketch().compact().getCompactBytes() + ); + } + + // non sketch that is too short to contain header should fail with regular java buffer exception + final byte[] garbage = new byte[]{0x01, 0x02}; + final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).getSketch().compact().getCompactBytes() + ); + + // non sketch that is long enough to check (this one doesn't actually need 'safe' read) + final byte[] garbageLonger = StringUtils.toUtf8("notasketch"); + final ByteBuffer buf4 = ByteBuffer.wrap(garbageLonger).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + SketchesArgumentException.class, + () -> objectStrategy.fromByteBufferSafe(buf4, garbageLonger.length).getSketch().compact().getCompactBytes() + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderTest.java new file mode 100644 index 00000000000..ef68fdeb8c5 --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.theta; + +import org.apache.datasketches.Family; +import org.apache.datasketches.theta.SetOperation; +import org.apache.datasketches.theta.Union; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class SketchHolderTest +{ + @Test + public void testDeserializeSafe() + { + Union union = (Union) SetOperation.builder().setNominalEntries(1024).build(Family.UNION); + union.update(1234L); + final byte[] bytes = union.getResult().toByteArray(); + final String base64 = StringUtils.encodeBase64String(bytes); + + Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(union.getResult()).getSketch().toByteArray()); + Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(bytes).getSketch().toByteArray()); + Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(base64).getSketch().toByteArray()); + + final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 10); + Assert.assertThrows(IndexOutOfBoundsException.class, () -> SketchHolder.deserializeSafe(trunacted)); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> SketchHolder.deserializeSafe(StringUtils.encodeBase64String(trunacted)) + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategyTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategyTest.java new file mode 100644 index 00000000000..ee59ddf5764 --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategyTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.tuple; + +import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketch; +import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketchBuilder; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class ArrayOfDoublesSketchObjectStrategyTest +{ + @Test + public void testSafeRead() + { + ArrayOfDoublesSketchObjectStrategy objectStrategy = new ArrayOfDoublesSketchObjectStrategy(); + ArrayOfDoublesUpdatableSketch sketch = new ArrayOfDoublesUpdatableSketchBuilder().setNominalEntries(1024) + .setNumberOfValues(4) + .build(); + sketch.update(1L, new double[]{1.0, 2.0, 3.0, 4.0}); + + final byte[] bytes = sketch.compact().toByteArray(); + + ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + + // valid sketch should not explode when copied, which reads the memory + objectStrategy.fromByteBufferSafe(buf, bytes.length).compact().toByteArray(); + + // corrupted sketch should fail with a regular java buffer exception + for (int subset = 3; subset < bytes.length - 1; subset++) { + final byte[] garbage2 = new byte[subset]; + for (int i = 0; i < garbage2.length; i++) { + garbage2[i] = buf.get(i); + } + + final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).compact().toByteArray() + ); + } + + // non sketch that is too short to contain header should fail with regular java buffer exception + final byte[] garbage = new byte[]{0x01, 0x02}; + final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).compact().toByteArray() + ); + } +} diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperationsTest.java new file mode 100644 index 00000000000..415f3acab97 --- /dev/null +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperationsTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.tuple; + +import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketch; +import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketchBuilder; +import org.apache.druid.java.util.common.StringUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class ArrayOfDoublesSketchOperationsTest +{ + @Test + public void testDeserializeSafe() + { + ArrayOfDoublesSketchObjectStrategy objectStrategy = new ArrayOfDoublesSketchObjectStrategy(); + ArrayOfDoublesUpdatableSketch sketch = new ArrayOfDoublesUpdatableSketchBuilder().setNominalEntries(1024) + .setNumberOfValues(4) + .build(); + sketch.update(1L, new double[]{1.0, 2.0, 3.0, 4.0}); + + final byte[] bytes = sketch.toByteArray(); + final String base64 = StringUtils.encodeBase64String(bytes); + + Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(sketch).toByteArray()); + Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(bytes).toByteArray()); + Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(base64).toByteArray()); + + final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 10); + Assert.assertThrows(IndexOutOfBoundsException.class, () -> ArrayOfDoublesSketchOperations.deserializeSafe(trunacted)); + Assert.assertThrows( + IndexOutOfBoundsException.class, + () -> ArrayOfDoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(trunacted)) + ); + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java b/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java index 351f2665d05..d05ba208585 100644 --- a/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java +++ b/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java @@ -90,6 +90,6 @@ public class ObjectStrategyComplexTypeStrategy implements TypeStrategy @Override public T fromBytes(byte[] value) { - return objectStrategy.fromByteBuffer(ByteBuffer.wrap(value), value.length); + return objectStrategy.fromByteBufferSafe(ByteBuffer.wrap(value), value.length); } } diff --git a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java b/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java index 8a53fc57a7d..eba97d04bbb 100644 --- a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java +++ b/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java @@ -79,4 +79,31 @@ public interface ObjectStrategy extends Comparator out.write(bytes); } } + + /** + * Convert values from their underlying byte representation, when the underlying bytes might be corrupted or + * maliciously constructed + * + * Implementations of this method absolutely must never perform any sun.misc.Unsafe based memory read or write + * operations from instructions contained in the data read from this buffer without first validating the data. If the + * data cannot be validated, all read and write operations from instructions in this data must be done directly with + * the {@link ByteBuffer} methods, or using {@link SafeWritableMemory} if + * {@link org.apache.datasketches.memory.Memory} is employed to materialize the value. + * + * Implementations of this method may change the given buffer's mark, or limit, and position. + * + * Implementations of this method may not store the given buffer in a field of the "deserialized" object, + * need to use {@link ByteBuffer#slice()}, {@link ByteBuffer#asReadOnlyBuffer()} or {@link ByteBuffer#duplicate()} in + * this case. + * + * + * @param buffer buffer to read value from + * @param numBytes number of bytes used to store the value, starting at buffer.position() + * @return an object created from the given byte buffer representation + */ + @Nullable + default T fromByteBufferSafe(ByteBuffer buffer, int numBytes) + { + return fromByteBuffer(buffer, numBytes); + } } diff --git a/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBase.java b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBase.java new file mode 100644 index 00000000000..df2fc14d053 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBase.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import com.google.common.base.Preconditions; +import com.google.common.primitives.Ints; +import org.apache.datasketches.memory.BaseState; +import org.apache.datasketches.memory.MemoryRequestServer; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.datasketches.memory.internal.BaseStateImpl; +import org.apache.datasketches.memory.internal.UnsafeUtil; +import org.apache.datasketches.memory.internal.XxHash64; +import org.apache.druid.java.util.common.StringUtils; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * Base class for making a regular {@link ByteBuffer} look like a {@link org.apache.datasketches.memory.Memory} or + * {@link org.apache.datasketches.memory.Buffer}. All methods delegate directly to the {@link ByteBuffer} rather + * than using 'unsafe' reads. + * + * @see SafeWritableMemory + * @see SafeWritableBuffer + */ + +@SuppressWarnings("unused") +public abstract class SafeWritableBase implements BaseState +{ + static final MemoryRequestServer SAFE_HEAP_REQUEST_SERVER = new HeapByteBufferMemoryRequestServer(); + + final ByteBuffer buffer; + + public SafeWritableBase(ByteBuffer buffer) + { + this.buffer = buffer; + } + + public MemoryRequestServer getMemoryRequestServer() + { + return SAFE_HEAP_REQUEST_SERVER; + } + + public boolean getBoolean(long offsetBytes) + { + return getByte(Ints.checkedCast(offsetBytes)) != 0; + } + + public byte getByte(long offsetBytes) + { + return buffer.get(Ints.checkedCast(offsetBytes)); + } + + public char getChar(long offsetBytes) + { + return buffer.getChar(Ints.checkedCast(offsetBytes)); + } + + public double getDouble(long offsetBytes) + { + return buffer.getDouble(Ints.checkedCast(offsetBytes)); + } + + public float getFloat(long offsetBytes) + { + return buffer.getFloat(Ints.checkedCast(offsetBytes)); + } + + public int getInt(long offsetBytes) + { + return buffer.getInt(Ints.checkedCast(offsetBytes)); + } + + public long getLong(long offsetBytes) + { + return buffer.getLong(Ints.checkedCast(offsetBytes)); + } + + public short getShort(long offsetBytes) + { + return buffer.getShort(Ints.checkedCast(offsetBytes)); + } + + public void putBoolean(long offsetBytes, boolean value) + { + buffer.put(Ints.checkedCast(offsetBytes), (byte) (value ? 1 : 0)); + } + + public void putByte(long offsetBytes, byte value) + { + buffer.put(Ints.checkedCast(offsetBytes), value); + } + + public void putChar(long offsetBytes, char value) + { + buffer.putChar(Ints.checkedCast(offsetBytes), value); + } + + public void putDouble(long offsetBytes, double value) + { + buffer.putDouble(Ints.checkedCast(offsetBytes), value); + } + + public void putFloat(long offsetBytes, float value) + { + buffer.putFloat(Ints.checkedCast(offsetBytes), value); + } + + public void putInt(long offsetBytes, int value) + { + buffer.putInt(Ints.checkedCast(offsetBytes), value); + } + + public void putLong(long offsetBytes, long value) + { + buffer.putLong(Ints.checkedCast(offsetBytes), value); + } + + public void putShort(long offsetBytes, short value) + { + buffer.putShort(Ints.checkedCast(offsetBytes), value); + } + + @Override + public ByteOrder getTypeByteOrder() + { + return buffer.order(); + } + + @Override + public boolean isByteOrderCompatible(ByteOrder byteOrder) + { + return buffer.order().equals(byteOrder); + } + + @Override + public ByteBuffer getByteBuffer() + { + return buffer; + } + + @Override + public long getCapacity() + { + return buffer.capacity(); + } + + @Override + public long getCumulativeOffset() + { + return 0; + } + + @Override + public long getCumulativeOffset(long offsetBytes) + { + return offsetBytes; + } + + @Override + public long getRegionOffset() + { + return 0; + } + + @Override + public long getRegionOffset(long offsetBytes) + { + return offsetBytes; + } + + @Override + public boolean hasArray() + { + return false; + } + + @Override + public long xxHash64(long offsetBytes, long lengthBytes, long seed) + { + return hash(buffer, offsetBytes, lengthBytes, seed); + } + + @Override + public long xxHash64(long in, long seed) + { + return XxHash64.hash(in, seed); + } + + @Override + public boolean hasByteBuffer() + { + return true; + } + + @Override + public boolean isDirect() + { + return false; + } + + @Override + public boolean isReadOnly() + { + return false; + } + + @Override + public boolean isSameResource(Object that) + { + return this.equals(that); + } + + @Override + public boolean isValid() + { + return true; + } + + @Override + public void checkValidAndBounds(long offsetBytes, long lengthBytes) + { + Preconditions.checkArgument( + Ints.checkedCast(offsetBytes) < buffer.limit(), + "start offset %s is greater than buffer limit %s", + offsetBytes, + buffer.limit() + ); + Preconditions.checkArgument( + Ints.checkedCast(offsetBytes + lengthBytes) < buffer.limit(), + "end offset %s is greater than buffer limit %s", + offsetBytes + lengthBytes, + buffer.limit() + ); + } + + /** + * Adapted from {@link BaseStateImpl#toHexString(String, long, int)} + */ + @Override + public String toHexString(String header, long offsetBytes, int lengthBytes) + { + final String klass = this.getClass().getSimpleName(); + final String s1 = StringUtils.format("(..., %d, %d)", offsetBytes, lengthBytes); + final long hcode = hashCode() & 0XFFFFFFFFL; + final String call = ".toHexString" + s1 + ", hashCode: " + hcode; + String sb = "### " + klass + " SUMMARY ###" + UnsafeUtil.LS + + "Header Comment : " + header + UnsafeUtil.LS + + "Call Parameters : " + call; + return toHex(this, sb, offsetBytes, lengthBytes); + } + + /** + * Adapted from {@link BaseStateImpl#toHex(BaseStateImpl, String, long, int)} + */ + static String toHex( + final SafeWritableBase state, + final String preamble, + final long offsetBytes, + final int lengthBytes + ) + { + final String lineSeparator = UnsafeUtil.LS; + final long capacity = state.getCapacity(); + UnsafeUtil.checkBounds(offsetBytes, lengthBytes, capacity); + final StringBuilder sb = new StringBuilder(); + final String uObjStr; + final long uObjHeader; + uObjStr = "null"; + uObjHeader = 0; + final ByteBuffer bb = state.getByteBuffer(); + final String bbStr = bb == null ? "null" + : bb.getClass().getSimpleName() + ", " + (bb.hashCode() & 0XFFFFFFFFL); + final MemoryRequestServer memReqSvr = state.getMemoryRequestServer(); + final String memReqStr = memReqSvr != null + ? memReqSvr.getClass().getSimpleName() + ", " + (memReqSvr.hashCode() & 0XFFFFFFFFL) + : "null"; + final long cumBaseOffset = state.getCumulativeOffset(); + sb.append(preamble).append(lineSeparator); + sb.append("UnsafeObj, hashCode : ").append(uObjStr).append(lineSeparator); + sb.append("UnsafeObjHeader : ").append(uObjHeader).append(lineSeparator); + sb.append("ByteBuf, hashCode : ").append(bbStr).append(lineSeparator); + sb.append("RegionOffset : ").append(state.getRegionOffset()).append(lineSeparator); + sb.append("Capacity : ").append(capacity).append(lineSeparator); + sb.append("CumBaseOffset : ").append(cumBaseOffset).append(lineSeparator); + sb.append("MemReq, hashCode : ").append(memReqStr).append(lineSeparator); + sb.append("Valid : ").append(state.isValid()).append(lineSeparator); + sb.append("Read Only : ").append(state.isReadOnly()).append(lineSeparator); + sb.append("Type Byte Order : ").append(state.getTypeByteOrder()).append(lineSeparator); + sb.append("Native Byte Order : ").append(ByteOrder.nativeOrder()).append(lineSeparator); + sb.append("JDK Runtime Version : ").append(UnsafeUtil.JDK).append(lineSeparator); + //Data detail + sb.append("Data, littleEndian : 0 1 2 3 4 5 6 7"); + + for (long i = 0; i < lengthBytes; i++) { + final int b = state.getByte(cumBaseOffset + offsetBytes + i) & 0XFF; + if (i % 8 == 0) { //row header + sb.append(StringUtils.format("%n%20s: ", offsetBytes + i)); + } + sb.append(StringUtils.format("%02x ", b)); + } + sb.append(lineSeparator); + + return sb.toString(); + } + + // copied from datasketches-memory XxHash64.java + private static final long P1 = -7046029288634856825L; + private static final long P2 = -4417276706812531889L; + private static final long P3 = 1609587929392839161L; + private static final long P4 = -8796714831421723037L; + private static final long P5 = 2870177450012600261L; + + /** + * Adapted from {@link XxHash64#hash(Object, long, long, long)} to work with {@link ByteBuffer} + */ + static long hash(ByteBuffer memory, long cumOffsetBytes, final long lengthBytes, final long seed) + { + long hash; + long remaining = lengthBytes; + int offset = Ints.checkedCast(cumOffsetBytes); + + if (remaining >= 32) { + long v1 = seed + P1 + P2; + long v2 = seed + P2; + long v3 = seed; + long v4 = seed - P1; + + do { + v1 += memory.getLong(offset) * P2; + v1 = Long.rotateLeft(v1, 31); + v1 *= P1; + + v2 += memory.getLong(offset + 8) * P2; + v2 = Long.rotateLeft(v2, 31); + v2 *= P1; + + v3 += memory.getLong(offset + 16) * P2; + v3 = Long.rotateLeft(v3, 31); + v3 *= P1; + + v4 += memory.getLong(offset + 24) * P2; + v4 = Long.rotateLeft(v4, 31); + v4 *= P1; + + offset += 32; + remaining -= 32; + } while (remaining >= 32); + + hash = Long.rotateLeft(v1, 1) + + Long.rotateLeft(v2, 7) + + Long.rotateLeft(v3, 12) + + Long.rotateLeft(v4, 18); + + v1 *= P2; + v1 = Long.rotateLeft(v1, 31); + v1 *= P1; + hash ^= v1; + hash = (hash * P1) + P4; + + v2 *= P2; + v2 = Long.rotateLeft(v2, 31); + v2 *= P1; + hash ^= v2; + hash = (hash * P1) + P4; + + v3 *= P2; + v3 = Long.rotateLeft(v3, 31); + v3 *= P1; + hash ^= v3; + hash = (hash * P1) + P4; + + v4 *= P2; + v4 = Long.rotateLeft(v4, 31); + v4 *= P1; + hash ^= v4; + hash = (hash * P1) + P4; + } else { //end remaining >= 32 + hash = seed + P5; + } + + hash += lengthBytes; + + while (remaining >= 8) { + long k1 = memory.getLong(offset); + k1 *= P2; + k1 = Long.rotateLeft(k1, 31); + k1 *= P1; + hash ^= k1; + hash = (Long.rotateLeft(hash, 27) * P1) + P4; + offset += 8; + remaining -= 8; + } + + if (remaining >= 4) { //treat as unsigned ints + hash ^= (memory.getInt(offset) & 0XFFFF_FFFFL) * P1; + hash = (Long.rotateLeft(hash, 23) * P2) + P3; + offset += 4; + remaining -= 4; + } + + while (remaining != 0) { //treat as unsigned bytes + hash ^= (memory.get(offset) & 0XFFL) * P5; + hash = Long.rotateLeft(hash, 11) * P1; + --remaining; + ++offset; + } + + hash ^= hash >>> 33; + hash *= P2; + hash ^= hash >>> 29; + hash *= P3; + hash ^= hash >>> 32; + return hash; + } + + private static class HeapByteBufferMemoryRequestServer implements MemoryRequestServer + { + @Override + public WritableMemory request(WritableMemory currentWritableMemory, long capacityBytes) + { + ByteBuffer newBuffer = ByteBuffer.allocate(Ints.checkedCast(capacityBytes)); + newBuffer.order(currentWritableMemory.getTypeByteOrder()); + return new SafeWritableMemory(newBuffer); + } + + @Override + public void requestClose(WritableMemory memToClose, WritableMemory newMemory) + { + // do nothing + } + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBuffer.java b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBuffer.java new file mode 100644 index 00000000000..3da7e70b457 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBuffer.java @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import com.google.common.primitives.Ints; +import org.apache.datasketches.memory.BaseBuffer; +import org.apache.datasketches.memory.Buffer; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.memory.WritableBuffer; +import org.apache.datasketches.memory.WritableMemory; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * Safety first! Don't trust something whose contents you locations to read and write stuff to, but need a + * {@link Buffer} or {@link WritableBuffer}? use this! + *

+ * Delegates everything to an underlying {@link ByteBuffer} so all read and write operations will have bounds checks + * built in rather than using 'unsafe'. + */ +public class SafeWritableBuffer extends SafeWritableBase implements WritableBuffer +{ + private int start; + private int end; + + public SafeWritableBuffer(ByteBuffer buffer) + { + super(buffer); + this.start = 0; + this.buffer.position(0); + this.end = buffer.capacity(); + } + + @Override + public WritableBuffer writableDuplicate() + { + return writableDuplicate(buffer.order()); + } + + @Override + public WritableBuffer writableDuplicate(ByteOrder byteOrder) + { + ByteBuffer dupe = buffer.duplicate(); + dupe.order(byteOrder); + WritableBuffer duplicate = new SafeWritableBuffer(dupe); + duplicate.setStartPositionEnd(start, buffer.position(), end); + return duplicate; + } + + @Override + public WritableBuffer writableRegion() + { + ByteBuffer dupe = buffer.duplicate().order(buffer.order()); + dupe.position(start); + dupe.limit(end); + ByteBuffer remaining = buffer.slice(); + remaining.order(dupe.order()); + return new SafeWritableBuffer(remaining); + } + + @Override + public WritableBuffer writableRegion(long offsetBytes, long capacityBytes, ByteOrder byteOrder) + { + ByteBuffer dupe = buffer.duplicate(); + dupe.position(Ints.checkedCast(offsetBytes)); + dupe.limit(dupe.position() + Ints.checkedCast(capacityBytes)); + return new SafeWritableBuffer(dupe.slice().order(byteOrder)); + } + + @Override + public WritableMemory asWritableMemory(ByteOrder byteOrder) + { + ByteBuffer dupe = buffer.duplicate(); + dupe.order(byteOrder); + return new SafeWritableMemory(dupe); + } + + @Override + public void putBoolean(boolean value) + { + buffer.put((byte) (value ? 1 : 0)); + } + + @Override + public void putBooleanArray(boolean[] srcArray, int srcOffsetBooleans, int lengthBooleans) + { + for (int i = 0; i < lengthBooleans; i++) { + putBoolean(srcArray[srcOffsetBooleans + i]); + } + } + + @Override + public void putByte(byte value) + { + buffer.put(value); + } + + @Override + public void putByteArray(byte[] srcArray, int srcOffsetBytes, int lengthBytes) + { + buffer.put(srcArray, srcOffsetBytes, lengthBytes); + } + + @Override + public void putChar(char value) + { + buffer.putChar(value); + } + + @Override + public void putCharArray(char[] srcArray, int srcOffsetChars, int lengthChars) + { + for (int i = 0; i < lengthChars; i++) { + buffer.putChar(srcArray[srcOffsetChars + i]); + } + } + + @Override + public void putDouble(double value) + { + buffer.putDouble(value); + } + + @Override + public void putDoubleArray(double[] srcArray, int srcOffsetDoubles, int lengthDoubles) + { + for (int i = 0; i < lengthDoubles; i++) { + buffer.putDouble(srcArray[srcOffsetDoubles + i]); + } + } + + @Override + public void putFloat(float value) + { + buffer.putFloat(value); + } + + @Override + public void putFloatArray(float[] srcArray, int srcOffsetFloats, int lengthFloats) + { + for (int i = 0; i < lengthFloats; i++) { + buffer.putFloat(srcArray[srcOffsetFloats + i]); + } + } + + @Override + public void putInt(int value) + { + buffer.putInt(value); + } + + @Override + public void putIntArray(int[] srcArray, int srcOffsetInts, int lengthInts) + { + for (int i = 0; i < lengthInts; i++) { + buffer.putInt(srcArray[srcOffsetInts + i]); + } + } + + @Override + public void putLong(long value) + { + buffer.putLong(value); + } + + @Override + public void putLongArray(long[] srcArray, int srcOffsetLongs, int lengthLongs) + { + for (int i = 0; i < lengthLongs; i++) { + buffer.putLong(srcArray[srcOffsetLongs + i]); + } + } + + @Override + public void putShort(short value) + { + buffer.putShort(value); + } + + @Override + public void putShortArray(short[] srcArray, int srcOffsetShorts, int lengthShorts) + { + for (int i = 0; i < lengthShorts; i++) { + buffer.putShort(srcArray[srcOffsetShorts + i]); + } + } + + @Override + public Object getArray() + { + return null; + } + + @Override + public void clear() + { + fill((byte) 0); + } + + @Override + public void fill(byte value) + { + while (buffer.hasRemaining() && buffer.position() < end) { + buffer.put(value); + } + } + + @Override + public Buffer duplicate() + { + return writableDuplicate(); + } + + @Override + public Buffer duplicate(ByteOrder byteOrder) + { + return writableDuplicate(byteOrder); + } + + @Override + public Buffer region() + { + return writableRegion(); + } + + @Override + public Buffer region(long offsetBytes, long capacityBytes, ByteOrder byteOrder) + { + return writableRegion(offsetBytes, capacityBytes, byteOrder); + } + + @Override + public Memory asMemory(ByteOrder byteOrder) + { + return asWritableMemory(byteOrder); + } + + @Override + public boolean getBoolean() + { + return buffer.get() == 0 ? false : true; + } + + @Override + public void getBooleanArray(boolean[] dstArray, int dstOffsetBooleans, int lengthBooleans) + { + for (int i = 0; i < lengthBooleans; i++) { + dstArray[dstOffsetBooleans + i] = getBoolean(); + } + } + + @Override + public byte getByte() + { + return buffer.get(); + } + + @Override + public void getByteArray(byte[] dstArray, int dstOffsetBytes, int lengthBytes) + { + for (int i = 0; i < lengthBytes; i++) { + dstArray[dstOffsetBytes + i] = buffer.get(); + } + } + + @Override + public char getChar() + { + return buffer.getChar(); + } + + @Override + public void getCharArray(char[] dstArray, int dstOffsetChars, int lengthChars) + { + for (int i = 0; i < lengthChars; i++) { + dstArray[dstOffsetChars + i] = buffer.getChar(); + } + } + + @Override + public double getDouble() + { + return buffer.getDouble(); + } + + @Override + public void getDoubleArray(double[] dstArray, int dstOffsetDoubles, int lengthDoubles) + { + for (int i = 0; i < lengthDoubles; i++) { + dstArray[dstOffsetDoubles + i] = buffer.getDouble(); + } + } + + @Override + public float getFloat() + { + return buffer.getFloat(); + } + + @Override + public void getFloatArray(float[] dstArray, int dstOffsetFloats, int lengthFloats) + { + for (int i = 0; i < lengthFloats; i++) { + dstArray[dstOffsetFloats + i] = buffer.getFloat(); + } + } + + @Override + public int getInt() + { + return buffer.getInt(); + } + + @Override + public void getIntArray(int[] dstArray, int dstOffsetInts, int lengthInts) + { + for (int i = 0; i < lengthInts; i++) { + dstArray[dstOffsetInts + i] = buffer.getInt(); + } + } + + @Override + public long getLong() + { + return buffer.getLong(); + } + + @Override + public void getLongArray(long[] dstArray, int dstOffsetLongs, int lengthLongs) + { + for (int i = 0; i < lengthLongs; i++) { + dstArray[dstOffsetLongs + i] = buffer.getLong(); + } + } + + @Override + public short getShort() + { + return buffer.getShort(); + } + + @Override + public void getShortArray(short[] dstArray, int dstOffsetShorts, int lengthShorts) + { + for (int i = 0; i < lengthShorts; i++) { + dstArray[dstOffsetShorts + i] = buffer.getShort(); + } + } + + @Override + public int compareTo( + long thisOffsetBytes, + long thisLengthBytes, + Buffer that, + long thatOffsetBytes, + long thatLengthBytes + ) + { + final int thisLength = Ints.checkedCast(thisLengthBytes); + final int thatLength = Ints.checkedCast(thatLengthBytes); + + final int commonLength = Math.min(thisLength, thatLength); + + for (int i = 0; i < commonLength; i++) { + final int cmp = Byte.compare(getByte(thisOffsetBytes + i), that.getByte(thatOffsetBytes + i)); + if (cmp != 0) { + return cmp; + } + } + + return Integer.compare(thisLength, thatLength); + } + + @Override + public BaseBuffer incrementPosition(long increment) + { + buffer.position(buffer.position() + Ints.checkedCast(increment)); + return this; + } + + @Override + public BaseBuffer incrementAndCheckPosition(long increment) + { + checkInvariants(start, buffer.position() + increment, end, buffer.capacity()); + return incrementPosition(increment); + } + + @Override + public long getEnd() + { + return end; + } + + @Override + public long getPosition() + { + return buffer.position(); + } + + @Override + public long getStart() + { + return start; + } + + @Override + public long getRemaining() + { + return buffer.remaining(); + } + + @Override + public boolean hasRemaining() + { + return buffer.hasRemaining(); + } + + @Override + public BaseBuffer resetPosition() + { + buffer.position(start); + return this; + } + + @Override + public BaseBuffer setPosition(long position) + { + buffer.position(Ints.checkedCast(position)); + return this; + } + + @Override + public BaseBuffer setAndCheckPosition(long position) + { + checkInvariants(start, position, end, buffer.capacity()); + return setPosition(position); + } + + @Override + public BaseBuffer setStartPositionEnd(long start, long position, long end) + { + this.start = Ints.checkedCast(start); + this.end = Ints.checkedCast(end); + buffer.position(Ints.checkedCast(position)); + buffer.limit(this.end); + return this; + } + + @Override + public BaseBuffer setAndCheckStartPositionEnd(long start, long position, long end) + { + checkInvariants(start, position, end, buffer.capacity()); + return setStartPositionEnd(start, position, end); + } + + @Override + public boolean equalTo(long thisOffsetBytes, Object that, long thatOffsetBytes, long lengthBytes) + { + if (!(that instanceof SafeWritableBuffer)) { + return false; + } + return compareTo(thisOffsetBytes, lengthBytes, (SafeWritableBuffer) that, thatOffsetBytes, lengthBytes) == 0; + } + + /** + * Adapted from {@link org.apache.datasketches.memory.internal.BaseBufferImpl#checkInvariants(long, long, long, long)} + */ + static void checkInvariants(final long start, final long pos, final long end, final long cap) + { + if ((start | pos | end | cap | (pos - start) | (end - pos) | (cap - end)) < 0L) { + throw new IllegalArgumentException( + "Violation of Invariants: " + + "start: " + start + + " <= pos: " + pos + + " <= end: " + end + + " <= cap: " + cap + + "; (pos - start): " + (pos - start) + + ", (end - pos): " + (end - pos) + + ", (cap - end): " + (cap - end) + ); + } + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/data/SafeWritableMemory.java b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableMemory.java new file mode 100644 index 00000000000..9006ac5cec9 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableMemory.java @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import com.google.common.primitives.Ints; +import org.apache.datasketches.memory.Buffer; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.memory.Utf8CodingException; +import org.apache.datasketches.memory.WritableBuffer; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.druid.java.util.common.StringUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.WritableByteChannel; + +/** + * Safety first! Don't trust something whose contents you locations to read and write stuff to, but need a + * {@link Memory} or {@link WritableMemory}? use this! + *

+ * Delegates everything to an underlying {@link ByteBuffer} so all read and write operations will have bounds checks + * built in rather than using 'unsafe'. + */ +public class SafeWritableMemory extends SafeWritableBase implements WritableMemory +{ + public static SafeWritableMemory wrap(byte[] bytes) + { + return wrap(ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()), 0, bytes.length); + } + + public static SafeWritableMemory wrap(ByteBuffer buffer) + { + return wrap(buffer.duplicate().order(buffer.order()), 0, buffer.capacity()); + } + + public static SafeWritableMemory wrap(ByteBuffer buffer, ByteOrder byteOrder) + { + return wrap(buffer.duplicate().order(byteOrder), 0, buffer.capacity()); + } + + public static SafeWritableMemory wrap(ByteBuffer buffer, int offset, int size) + { + final ByteBuffer dupe = buffer.duplicate().order(buffer.order()); + dupe.position(offset); + dupe.limit(offset + size); + return new SafeWritableMemory(dupe.slice().order(buffer.order())); + } + + public SafeWritableMemory(ByteBuffer buffer) + { + super(buffer); + } + + @Override + public Memory region(long offsetBytes, long capacityBytes, ByteOrder byteOrder) + { + return writableRegion(offsetBytes, capacityBytes, byteOrder); + } + + @Override + public Buffer asBuffer(ByteOrder byteOrder) + { + return asWritableBuffer(byteOrder); + } + + @Override + public void getBooleanArray(long offsetBytes, boolean[] dstArray, int dstOffsetBooleans, int lengthBooleans) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthBooleans; j++) { + dstArray[dstOffsetBooleans + j] = buffer.get(offset + j) != 0; + } + } + + @Override + public void getByteArray(long offsetBytes, byte[] dstArray, int dstOffsetBytes, int lengthBytes) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthBytes; j++) { + dstArray[dstOffsetBytes + j] = buffer.get(offset + j); + } + } + + @Override + public void getCharArray(long offsetBytes, char[] dstArray, int dstOffsetChars, int lengthChars) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthChars; j++) { + dstArray[dstOffsetChars + j] = buffer.getChar(offset + (j * Character.BYTES)); + } + } + + @Override + public int getCharsFromUtf8(long offsetBytes, int utf8LengthBytes, Appendable dst) + throws IOException, Utf8CodingException + { + ByteBuffer dupe = buffer.asReadOnlyBuffer().order(buffer.order()); + dupe.position(Ints.checkedCast(offsetBytes)); + String s = StringUtils.fromUtf8(dupe, utf8LengthBytes); + dst.append(s); + return s.length(); + } + + @Override + public int getCharsFromUtf8(long offsetBytes, int utf8LengthBytes, StringBuilder dst) throws Utf8CodingException + { + ByteBuffer dupe = buffer.asReadOnlyBuffer().order(buffer.order()); + dupe.position(Ints.checkedCast(offsetBytes)); + String s = StringUtils.fromUtf8(dupe, utf8LengthBytes); + dst.append(s); + return s.length(); + } + + @Override + public void getDoubleArray(long offsetBytes, double[] dstArray, int dstOffsetDoubles, int lengthDoubles) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthDoubles; j++) { + dstArray[dstOffsetDoubles + j] = buffer.getDouble(offset + (j * Double.BYTES)); + } + } + + @Override + public void getFloatArray(long offsetBytes, float[] dstArray, int dstOffsetFloats, int lengthFloats) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthFloats; j++) { + dstArray[dstOffsetFloats + j] = buffer.getFloat(offset + (j * Float.BYTES)); + } + } + + @Override + public void getIntArray(long offsetBytes, int[] dstArray, int dstOffsetInts, int lengthInts) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthInts; j++) { + dstArray[dstOffsetInts + j] = buffer.getInt(offset + (j * Integer.BYTES)); + } + } + + @Override + public void getLongArray(long offsetBytes, long[] dstArray, int dstOffsetLongs, int lengthLongs) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthLongs; j++) { + dstArray[dstOffsetLongs + j] = buffer.getLong(offset + (j * Long.BYTES)); + } + } + + @Override + public void getShortArray(long offsetBytes, short[] dstArray, int dstOffsetShorts, int lengthShorts) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int j = 0; j < lengthShorts; j++) { + dstArray[dstOffsetShorts + j] = buffer.getShort(offset + (j * Short.BYTES)); + } + } + + @Override + public int compareTo( + long thisOffsetBytes, + long thisLengthBytes, + Memory that, + long thatOffsetBytes, + long thatLengthBytes + ) + { + final int thisLength = Ints.checkedCast(thisLengthBytes); + final int thatLength = Ints.checkedCast(thatLengthBytes); + + final int commonLength = Math.min(thisLength, thatLength); + + for (int i = 0; i < commonLength; i++) { + final int cmp = Byte.compare(getByte(thisOffsetBytes + i), that.getByte(thatOffsetBytes + i)); + if (cmp != 0) { + return cmp; + } + } + + return Integer.compare(thisLength, thatLength); + } + + @Override + public void copyTo(long srcOffsetBytes, WritableMemory destination, long dstOffsetBytes, long lengthBytes) + { + int offset = Ints.checkedCast(srcOffsetBytes); + for (int i = 0; i < lengthBytes; i++) { + destination.putByte(dstOffsetBytes + i, buffer.get(offset + i)); + } + } + + @Override + public void writeTo(long offsetBytes, long lengthBytes, WritableByteChannel out) throws IOException + { + ByteBuffer dupe = buffer.duplicate(); + dupe.position(Ints.checkedCast(offsetBytes)); + dupe.limit(dupe.position() + Ints.checkedCast(lengthBytes)); + ByteBuffer view = dupe.slice(); + view.order(buffer.order()); + out.write(view); + } + + @Override + public boolean equalTo(long thisOffsetBytes, Object that, long thatOffsetBytes, long lengthBytes) + { + if (!(that instanceof SafeWritableMemory)) { + return false; + } + return compareTo(thisOffsetBytes, lengthBytes, (SafeWritableMemory) that, thatOffsetBytes, lengthBytes) == 0; + } + + + @Override + public WritableMemory writableRegion(long offsetBytes, long capacityBytes, ByteOrder byteOrder) + { + final ByteBuffer dupe = buffer.duplicate().order(buffer.order()); + final int sizeBytes = Ints.checkedCast(capacityBytes); + dupe.position(Ints.checkedCast(offsetBytes)); + dupe.limit(dupe.position() + sizeBytes); + final ByteBuffer view = dupe.slice(); + view.order(byteOrder); + return new SafeWritableMemory(view); + } + + @Override + public WritableBuffer asWritableBuffer(ByteOrder byteOrder) + { + return new SafeWritableBuffer(buffer.duplicate().order(byteOrder)); + } + + @Override + public void putBooleanArray(long offsetBytes, boolean[] srcArray, int srcOffsetBooleans, int lengthBooleans) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthBooleans; i++) { + buffer.put(offset + i, (byte) (srcArray[i + srcOffsetBooleans] ? 1 : 0)); + } + } + + @Override + public void putByteArray(long offsetBytes, byte[] srcArray, int srcOffsetBytes, int lengthBytes) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthBytes; i++) { + buffer.put(offset + i, srcArray[srcOffsetBytes + i]); + } + } + + @Override + public void putCharArray(long offsetBytes, char[] srcArray, int srcOffsetChars, int lengthChars) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthChars; i++) { + buffer.putChar(offset + (i * Character.BYTES), srcArray[srcOffsetChars + i]); + } + } + + @Override + public long putCharsToUtf8(long offsetBytes, CharSequence src) + { + final byte[] bytes = StringUtils.toUtf8(src.toString()); + putByteArray(offsetBytes, bytes, 0, bytes.length); + return bytes.length; + } + + @Override + public void putDoubleArray(long offsetBytes, double[] srcArray, int srcOffsetDoubles, int lengthDoubles) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthDoubles; i++) { + buffer.putDouble(offset + (i * Double.BYTES), srcArray[srcOffsetDoubles + i]); + } + } + + @Override + public void putFloatArray(long offsetBytes, float[] srcArray, int srcOffsetFloats, int lengthFloats) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthFloats; i++) { + buffer.putFloat(offset + (i * Float.BYTES), srcArray[srcOffsetFloats + i]); + } + } + + @Override + public void putIntArray(long offsetBytes, int[] srcArray, int srcOffsetInts, int lengthInts) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthInts; i++) { + buffer.putInt(offset + (i * Integer.BYTES), srcArray[srcOffsetInts + i]); + } + } + + @Override + public void putLongArray(long offsetBytes, long[] srcArray, int srcOffsetLongs, int lengthLongs) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthLongs; i++) { + buffer.putLong(offset + (i * Long.BYTES), srcArray[srcOffsetLongs + i]); + } + } + + @Override + public void putShortArray(long offsetBytes, short[] srcArray, int srcOffsetShorts, int lengthShorts) + { + final int offset = Ints.checkedCast(offsetBytes); + for (int i = 0; i < lengthShorts; i++) { + buffer.putShort(offset + (i * Short.BYTES), srcArray[srcOffsetShorts + i]); + } + } + + @Override + public long getAndAddLong(long offsetBytes, long delta) + { + final int offset = Ints.checkedCast(offsetBytes); + final long currentValue; + synchronized (buffer) { + currentValue = buffer.getLong(offset); + buffer.putLong(offset, currentValue + delta); + } + return currentValue; + } + + @Override + public boolean compareAndSwapLong(long offsetBytes, long expect, long update) + { + final int offset = Ints.checkedCast(offsetBytes); + synchronized (buffer) { + final long actual = buffer.getLong(offset); + if (expect == actual) { + buffer.putLong(offset, update); + return true; + } + } + return false; + } + + @Override + public long getAndSetLong(long offsetBytes, long newValue) + { + int offset = Ints.checkedCast(offsetBytes); + synchronized (buffer) { + long l = buffer.getLong(offset); + buffer.putLong(offset, newValue); + return l; + } + } + + @Override + public Object getArray() + { + return null; + } + + @Override + public void clear() + { + fill((byte) 0); + } + + @Override + public void clear(long offsetBytes, long lengthBytes) + { + fill(offsetBytes, lengthBytes, (byte) 0); + } + + @Override + public void clearBits(long offsetBytes, byte bitMask) + { + final int offset = Ints.checkedCast(offsetBytes); + int value = buffer.get(offset) & 0XFF; + value &= ~bitMask; + buffer.put(offset, (byte) value); + } + + @Override + public void fill(byte value) + { + for (int i = 0; i < buffer.capacity(); i++) { + buffer.put(i, value); + } + } + + @Override + public void fill(long offsetBytes, long lengthBytes, byte value) + { + int offset = Ints.checkedCast(offsetBytes); + int length = Ints.checkedCast(lengthBytes); + for (int i = 0; i < length; i++) { + buffer.put(offset + i, value); + } + } + + @Override + public void setBits(long offsetBytes, byte bitMask) + { + final int offset = Ints.checkedCast(offsetBytes); + buffer.put(offset, (byte) (buffer.get(offset) | bitMask)); + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/data/SafeWritableBufferTest.java b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableBufferTest.java new file mode 100644 index 00000000000..f432b7c167c --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableBufferTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import org.apache.datasketches.memory.Buffer; +import org.apache.datasketches.memory.WritableBuffer; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class SafeWritableBufferTest +{ + private static final int CAPACITY = 1024; + + @Test + public void testPutAndGet() + { + WritableBuffer b1 = getBuffer(); + Assert.assertEquals(0, b1.getPosition()); + b1.putByte((byte) 0x01); + Assert.assertEquals(1, b1.getPosition()); + b1.putBoolean(true); + Assert.assertEquals(2, b1.getPosition()); + b1.putBoolean(false); + Assert.assertEquals(3, b1.getPosition()); + b1.putChar('c'); + Assert.assertEquals(5, b1.getPosition()); + b1.putDouble(1.1); + Assert.assertEquals(13, b1.getPosition()); + b1.putFloat(1.1f); + Assert.assertEquals(17, b1.getPosition()); + b1.putInt(100); + Assert.assertEquals(21, b1.getPosition()); + b1.putLong(1000L); + Assert.assertEquals(29, b1.getPosition()); + b1.putShort((short) 15); + Assert.assertEquals(31, b1.getPosition()); + b1.resetPosition(); + + Assert.assertEquals(0x01, b1.getByte()); + Assert.assertTrue(b1.getBoolean()); + Assert.assertFalse(b1.getBoolean()); + Assert.assertEquals('c', b1.getChar()); + Assert.assertEquals(1.1, b1.getDouble(), 0.0); + Assert.assertEquals(1.1f, b1.getFloat(), 0.0); + Assert.assertEquals(100, b1.getInt()); + Assert.assertEquals(1000L, b1.getLong()); + Assert.assertEquals(15, b1.getShort()); + } + + @Test + public void testPutAndGetArrays() + { + WritableBuffer buffer = getBuffer(); + final byte[] b1 = new byte[]{0x01, 0x02, 0x08, 0x08}; + final byte[] b2 = new byte[b1.length]; + + final boolean[] bool1 = new boolean[]{true, false, false, true}; + final boolean[] bool2 = new boolean[bool1.length]; + + final char[] chars1 = new char[]{'a', 'b', 'c', 'd'}; + final char[] chars2 = new char[chars1.length]; + + final double[] double1 = new double[]{1.1, -2.2, 3.3, 4.4}; + final double[] double2 = new double[double1.length]; + + final float[] float1 = new float[]{1.1f, 2.2f, -3.3f, 4.4f}; + final float[] float2 = new float[float1.length]; + + final int[] ints1 = new int[]{1, 2, -3, 4}; + final int[] ints2 = new int[ints1.length]; + + final long[] longs1 = new long[]{1L, -2L, 3L, -14L}; + final long[] longs2 = new long[ints1.length]; + + final short[] shorts1 = new short[]{1, -2, 3, -14}; + final short[] shorts2 = new short[ints1.length]; + + buffer.putByteArray(b1, 0, 2); + buffer.putByteArray(b1, 2, b1.length - 2); + buffer.putBooleanArray(bool1, 0, bool1.length); + buffer.putCharArray(chars1, 0, chars1.length); + buffer.putDoubleArray(double1, 0, double1.length); + buffer.putFloatArray(float1, 0, float1.length); + buffer.putIntArray(ints1, 0, ints1.length); + buffer.putLongArray(longs1, 0, longs1.length); + buffer.putShortArray(shorts1, 0, shorts1.length); + long pos = buffer.getPosition(); + buffer.resetPosition(); + buffer.getByteArray(b2, 0, b1.length); + buffer.getBooleanArray(bool2, 0, bool1.length); + buffer.getCharArray(chars2, 0, chars1.length); + buffer.getDoubleArray(double2, 0, double1.length); + buffer.getFloatArray(float2, 0, float1.length); + buffer.getIntArray(ints2, 0, ints1.length); + buffer.getLongArray(longs2, 0, longs1.length); + buffer.getShortArray(shorts2, 0, shorts1.length); + + Assert.assertArrayEquals(b1, b2); + Assert.assertArrayEquals(bool1, bool2); + Assert.assertArrayEquals(chars1, chars2); + for (int i = 0; i < double1.length; i++) { + Assert.assertEquals(double1[i], double2[i], 0.0); + } + for (int i = 0; i < float1.length; i++) { + Assert.assertEquals(float1[i], float2[i], 0.0); + } + Assert.assertArrayEquals(ints1, ints2); + Assert.assertArrayEquals(longs1, longs2); + Assert.assertArrayEquals(shorts1, shorts2); + + Assert.assertEquals(pos, buffer.getPosition()); + } + + @Test + public void testStartEndRegionAndDuplicate() + { + WritableBuffer buffer = getBuffer(); + Assert.assertEquals(0, buffer.getPosition()); + Assert.assertEquals(0, buffer.getStart()); + Assert.assertEquals(CAPACITY, buffer.getEnd()); + Assert.assertEquals(CAPACITY, buffer.getRemaining()); + Assert.assertEquals(CAPACITY, buffer.getCapacity()); + Assert.assertTrue(buffer.hasRemaining()); + buffer.fill((byte) 0x07); + buffer.setAndCheckStartPositionEnd(10L, 15L, 100L); + Assert.assertEquals(15L, buffer.getPosition()); + Assert.assertEquals(10L, buffer.getStart()); + Assert.assertEquals(100L, buffer.getEnd()); + Assert.assertEquals(85L, buffer.getRemaining()); + Assert.assertEquals(CAPACITY, buffer.getCapacity()); + buffer.fill((byte) 0x70); + buffer.resetPosition(); + Assert.assertEquals(10L, buffer.getPosition()); + for (int i = 0; i < 90; i++) { + if (i < 5) { + Assert.assertEquals(0x07, buffer.getByte()); + } else { + Assert.assertEquals(0x70, buffer.getByte()); + } + } + buffer.setAndCheckPosition(50); + + Buffer duplicate = buffer.duplicate(); + Assert.assertEquals(buffer.getStart(), duplicate.getStart()); + Assert.assertEquals(buffer.getPosition(), duplicate.getPosition()); + Assert.assertEquals(buffer.getEnd(), duplicate.getEnd()); + Assert.assertEquals(buffer.getRemaining(), duplicate.getRemaining()); + Assert.assertEquals(buffer.getCapacity(), duplicate.getCapacity()); + + duplicate.resetPosition(); + for (int i = 0; i < 90; i++) { + if (i < 5) { + Assert.assertEquals(0x07, duplicate.getByte()); + } else { + Assert.assertEquals(0x70, duplicate.getByte()); + } + } + + Buffer region = buffer.region(5L, 105L, buffer.getTypeByteOrder()); + Assert.assertEquals(0, region.getStart()); + Assert.assertEquals(0, region.getPosition()); + Assert.assertEquals(105L, region.getEnd()); + Assert.assertEquals(105L, region.getRemaining()); + Assert.assertEquals(105L, region.getCapacity()); + + for (int i = 0; i < 105; i++) { + if (i < 10) { + Assert.assertEquals(0x07, region.getByte()); + } else if (i < 95) { + Assert.assertEquals(0x70, region.getByte()); + } else { + Assert.assertEquals(0x07, region.getByte()); + } + } + } + + @Test + public void testFill() + { + WritableBuffer buffer = getBuffer(); + WritableBuffer anotherBuffer = getBuffer(); + + buffer.fill((byte) 0x0F); + anotherBuffer.fill((byte) 0x0F); + Assert.assertTrue(buffer.equalTo(0L, anotherBuffer, 0L, CAPACITY)); + + anotherBuffer.setPosition(100); + anotherBuffer.clear(); + Assert.assertFalse(buffer.equalTo(0L, anotherBuffer, 0L, CAPACITY)); + Assert.assertTrue(buffer.equalTo(0L, anotherBuffer, 0L, 100L)); + } + + private WritableBuffer getBuffer() + { + return getBuffer(CAPACITY); + } + + private WritableBuffer getBuffer(int capacity) + { + final ByteBuffer aBuffer = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN); + SafeWritableBuffer memory = new SafeWritableBuffer(aBuffer); + return memory; + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/data/SafeWritableMemoryTest.java b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableMemoryTest.java new file mode 100644 index 00000000000..786443f43ed --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableMemoryTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.memory.WritableMemory; +import org.apache.datasketches.memory.internal.UnsafeUtil; +import org.junit.Assert; +import org.junit.Test; + +import java.io.CharArrayWriter; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class SafeWritableMemoryTest +{ + private static final int CAPACITY = 1024; + + @Test + public void testPutAndGet() + { + final WritableMemory memory = getMemory(); + memory.putByte(3L, (byte) 0x01); + Assert.assertEquals(memory.getByte(3L), 0x01); + + memory.putBoolean(1L, true); + Assert.assertTrue(memory.getBoolean(1L)); + memory.putBoolean(1L, false); + Assert.assertFalse(memory.getBoolean(1L)); + + memory.putChar(10L, 'c'); + Assert.assertEquals('c', memory.getChar(10L)); + + memory.putDouble(14L, 3.3); + Assert.assertEquals(3.3, memory.getDouble(14L), 0.0); + + memory.putFloat(27L, 3.3f); + Assert.assertEquals(3.3f, memory.getFloat(27L), 0.0); + + memory.putInt(11L, 1234); + Assert.assertEquals(1234, memory.getInt(11L)); + + memory.putLong(500L, 500L); + Assert.assertEquals(500L, memory.getLong(500L)); + + memory.putShort(11L, (short) 15); + Assert.assertEquals(15, memory.getShort(11L)); + + long l = memory.getAndSetLong(900L, 10L); + Assert.assertEquals(0L, l); + l = memory.getAndSetLong(900L, 100L); + Assert.assertEquals(10L, l); + l = memory.getAndAddLong(900L, 10L); + Assert.assertEquals(100L, l); + Assert.assertEquals(110L, memory.getLong(900L)); + Assert.assertTrue(memory.compareAndSwapLong(900L, 110L, 120L)); + Assert.assertFalse(memory.compareAndSwapLong(900L, 110L, 120L)); + Assert.assertEquals(120L, memory.getLong(900L)); + } + + @Test + public void testPutAndGetArrays() + { + final WritableMemory memory = getMemory(); + final byte[] b1 = new byte[]{0x01, 0x02, 0x08, 0x08}; + final byte[] b2 = new byte[b1.length]; + memory.putByteArray(12L, b1, 0, 3); + memory.putByteArray(15L, b1, 3, 1); + memory.getByteArray(12L, b2, 0, 3); + memory.getByteArray(15L, b2, 3, 1); + Assert.assertArrayEquals(b1, b2); + + final boolean[] bool1 = new boolean[]{true, false, false, true}; + final boolean[] bool2 = new boolean[bool1.length]; + memory.putBooleanArray(100L, bool1, 0, 2); + memory.putBooleanArray(102L, bool1, 2, 2); + memory.getBooleanArray(100L, bool2, 0, 2); + memory.getBooleanArray(102L, bool2, 2, 2); + Assert.assertArrayEquals(bool1, bool2); + + final char[] chars1 = new char[]{'a', 'b', 'c', 'd'}; + final char[] chars2 = new char[chars1.length]; + memory.putCharArray(10L, chars1, 0, 4); + memory.getCharArray(10L, chars2, 0, chars1.length); + Assert.assertArrayEquals(chars1, chars2); + + final double[] double1 = new double[]{1.1, -2.2, 3.3, 4.4}; + final double[] double2 = new double[double1.length]; + memory.putDoubleArray(100L, double1, 0, 1); + memory.putDoubleArray(100L + Double.BYTES, double1, 1, 3); + memory.getDoubleArray(100L, double2, 0, 2); + memory.getDoubleArray(100L + (2 * Double.BYTES), double2, 2, 2); + for (int i = 0; i < double1.length; i++) { + Assert.assertEquals(double1[i], double2[i], 0.0); + } + + final float[] float1 = new float[]{1.1f, 2.2f, -3.3f, 4.4f}; + final float[] float2 = new float[float1.length]; + memory.putFloatArray(100L, float1, 0, 1); + memory.putFloatArray(100L + Float.BYTES, float1, 1, 3); + memory.getFloatArray(100L, float2, 0, 2); + memory.getFloatArray(100L + (2 * Float.BYTES), float2, 2, 2); + for (int i = 0; i < float1.length; i++) { + Assert.assertEquals(float1[i], float2[i], 0.0); + } + + final int[] ints1 = new int[]{1, 2, -3, 4}; + final int[] ints2 = new int[ints1.length]; + memory.putIntArray(100L, ints1, 0, 1); + memory.putIntArray(100L + Integer.BYTES, ints1, 1, 3); + memory.getIntArray(100L, ints2, 0, 2); + memory.getIntArray(100L + (2 * Integer.BYTES), ints2, 2, 2); + Assert.assertArrayEquals(ints1, ints2); + + final long[] longs1 = new long[]{1L, -2L, 3L, -14L}; + final long[] longs2 = new long[ints1.length]; + memory.putLongArray(100L, longs1, 0, 1); + memory.putLongArray(100L + Long.BYTES, longs1, 1, 3); + memory.getLongArray(100L, longs2, 0, 2); + memory.getLongArray(100L + (2 * Long.BYTES), longs2, 2, 2); + Assert.assertArrayEquals(longs1, longs2); + + final short[] shorts1 = new short[]{1, -2, 3, -14}; + final short[] shorts2 = new short[ints1.length]; + memory.putShortArray(100L, shorts1, 0, 1); + memory.putShortArray(100L + Short.BYTES, shorts1, 1, 3); + memory.getShortArray(100L, shorts2, 0, 2); + memory.getShortArray(100L + (2 * Short.BYTES), shorts2, 2, 2); + Assert.assertArrayEquals(shorts1, shorts2); + } + + @Test + public void testFill() + { + final byte theByte = 0x01; + final byte anotherByte = 0x02; + final WritableMemory memory = getMemory(); + final int halfWay = (int) (memory.getCapacity() / 2); + + memory.fill(theByte); + for (int i = 0; i < memory.getCapacity(); i++) { + Assert.assertEquals(theByte, memory.getByte(i)); + } + + memory.fill(halfWay, memory.getCapacity() - halfWay, anotherByte); + for (int i = 0; i < memory.getCapacity(); i++) { + if (i < halfWay) { + Assert.assertEquals(theByte, memory.getByte(i)); + } else { + Assert.assertEquals(anotherByte, memory.getByte(i)); + } + } + + memory.clear(halfWay, memory.getCapacity() - halfWay); + for (int i = 0; i < memory.getCapacity(); i++) { + if (i < halfWay) { + Assert.assertEquals(theByte, memory.getByte(i)); + } else { + Assert.assertEquals(0, memory.getByte(i)); + } + } + + memory.setBits(halfWay - 1, anotherByte); + Assert.assertEquals(0x03, memory.getByte(halfWay - 1)); + memory.clearBits(halfWay - 1, theByte); + Assert.assertEquals(anotherByte, memory.getByte(halfWay - 1)); + + memory.clear(); + for (int i = 0; i < memory.getCapacity(); i++) { + Assert.assertEquals(0, memory.getByte(i)); + } + } + + @Test + public void testStringStuff() throws IOException + { + WritableMemory memory = getMemory(); + String s1 = "hello "; + memory.putCharsToUtf8(10L, s1); + + StringBuilder builder = new StringBuilder(); + memory.getCharsFromUtf8(10L, s1.length(), builder); + Assert.assertEquals(s1, builder.toString()); + + CharArrayWriter someAppendable = new CharArrayWriter(); + memory.getCharsFromUtf8(10L, s1.length(), someAppendable); + Assert.assertEquals(s1, someAppendable.toString()); + } + + @Test + public void testRegion() + { + WritableMemory memory = getMemory(); + Assert.assertEquals(CAPACITY, memory.getCapacity()); + Assert.assertEquals(0, memory.getCumulativeOffset()); + Assert.assertEquals(10L, memory.getCumulativeOffset(10L)); + Assert.assertThrows( + IllegalArgumentException.class, + () -> memory.checkValidAndBounds(CAPACITY - 10, 11L) + ); + + final byte[] someBytes = new byte[]{0x01, 0x02, 0x03, 0x04}; + memory.putByteArray(10L, someBytes, 0, someBytes.length); + + Memory region = memory.region(10L, someBytes.length); + Assert.assertEquals(someBytes.length, region.getCapacity()); + Assert.assertEquals(0, region.getCumulativeOffset()); + Assert.assertEquals(2L, region.getCumulativeOffset(2L)); + Assert.assertThrows( + IllegalArgumentException.class, + () -> region.checkValidAndBounds(2L, 4L) + ); + + final byte[] andBack = new byte[someBytes.length]; + region.getByteArray(0L, andBack, 0, someBytes.length); + Assert.assertArrayEquals(someBytes, andBack); + + Memory differentOrderRegion = memory.region(10L, someBytes.length, ByteOrder.BIG_ENDIAN); + // different order + Assert.assertFalse(region.isByteOrderCompatible(differentOrderRegion.getTypeByteOrder())); + // contents are equal tho + Assert.assertTrue(region.equalTo(0L, differentOrderRegion, 0L, someBytes.length)); + } + + @Test + public void testCompareAndEquals() + { + WritableMemory memory = getMemory(); + final byte[] someBytes = new byte[]{0x01, 0x02, 0x03, 0x04}; + final byte[] shorterSameBytes = new byte[]{0x01, 0x02, 0x03}; + final byte[] differentBytes = new byte[]{0x02, 0x02, 0x03, 0x04}; + memory.putByteArray(10L, someBytes, 0, someBytes.length); + memory.putByteArray(400L, someBytes, 0, someBytes.length); + memory.putByteArray(200L, shorterSameBytes, 0, shorterSameBytes.length); + memory.putByteArray(500L, differentBytes, 0, differentBytes.length); + + Assert.assertEquals(0, memory.compareTo(10L, someBytes.length, memory, 400L, someBytes.length)); + Assert.assertEquals(4, memory.compareTo(10L, someBytes.length, memory, 200L, someBytes.length)); + Assert.assertEquals(-1, memory.compareTo(10L, someBytes.length, memory, 500L, differentBytes.length)); + + WritableMemory memory2 = getMemory(); + memory2.putByteArray(0L, someBytes, 0, someBytes.length); + + Assert.assertEquals(0, memory.compareTo(10L, someBytes.length, memory2, 0L, someBytes.length)); + + Assert.assertTrue(memory.equalTo(10L, memory2, 0L, someBytes.length)); + + WritableMemory memory3 = getMemory(); + memory2.copyTo(0L, memory3, 0L, CAPACITY); + Assert.assertTrue(memory2.equalTo(0L, memory3, 0L, CAPACITY)); + } + + @Test + public void testHash() + { + WritableMemory memory = getMemory(); + final long[] someLongs = new long[]{1L, 10L, 100L, 1000L, 10000L}; + final int[] someInts = new int[]{1, 2, 3}; + final byte[] someBytes = new byte[]{0x01, 0x02, 0x03}; + final int longsLength = Long.BYTES * someLongs.length; + final int someIntsLength = Integer.BYTES * someInts.length; + final int totalLength = longsLength + someIntsLength + someBytes.length; + memory.putLongArray(2L, someLongs, 0, someLongs.length); + memory.putIntArray(2L + longsLength, someInts, 0, someInts.length); + memory.putByteArray(2L + longsLength + someIntsLength, someBytes, 0, someBytes.length); + Memory memory2 = Memory.wrap(memory.getByteBuffer(), ByteOrder.LITTLE_ENDIAN); + Assert.assertEquals( + memory2.xxHash64(2L, totalLength, 0), + memory.xxHash64(2L, totalLength, 0) + ); + + Assert.assertEquals( + memory2.xxHash64(2L, 0), + memory.xxHash64(2L, 0) + ); + } + + @Test + public void testToHexString() + { + + final byte[] bytes = new byte[]{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; + final WritableMemory memory = getMemory(bytes.length); + memory.putByteArray(0L, bytes, 0, bytes.length); + final long hcode = memory.hashCode() & 0XFFFFFFFFL; + final long bufferhcode = memory.getByteBuffer().hashCode() & 0XFFFFFFFFL; + final long reqhcode = memory.getMemoryRequestServer().hashCode() & 0XFFFFFFFFL; + Assert.assertEquals( + "### SafeWritableMemory SUMMARY ###\n" + + "Header Comment : test memory dump\n" + + "Call Parameters : .toHexString(..., 0, 8), hashCode: " + hcode + "\n" + + "UnsafeObj, hashCode : null\n" + + "UnsafeObjHeader : 0\n" + + "ByteBuf, hashCode : HeapByteBuffer, " + bufferhcode + "\n" + + "RegionOffset : 0\n" + + "Capacity : 8\n" + + "CumBaseOffset : 0\n" + + "MemReq, hashCode : HeapByteBufferMemoryRequestServer, " + reqhcode + "\n" + + "Valid : true\n" + + "Read Only : false\n" + + "Type Byte Order : LITTLE_ENDIAN\n" + + "Native Byte Order : LITTLE_ENDIAN\n" + + "JDK Runtime Version : " + UnsafeUtil.JDK + "\n" + + "Data, littleEndian : 0 1 2 3 4 5 6 7\n" + + " 0: 00 01 02 03 04 05 06 07 \n", + memory.toHexString("test memory dump", 0, bytes.length) + ); + } + + @Test + public void testMisc() + { + WritableMemory memory = getMemory(10); + WritableMemory memory2 = memory.getMemoryRequestServer().request(memory, 20); + Assert.assertEquals(20, memory2.getCapacity()); + + Assert.assertFalse(memory2.hasArray()); + + Assert.assertFalse(memory2.isReadOnly()); + Assert.assertFalse(memory2.isDirect()); + Assert.assertTrue(memory2.isValid()); + Assert.assertTrue(memory2.hasByteBuffer()); + + Assert.assertFalse(memory2.isSameResource(memory)); + Assert.assertTrue(memory2.isSameResource(memory2)); + + // does nothing + memory.getMemoryRequestServer().requestClose(memory, memory2); + } + + private WritableMemory getMemory() + { + return getMemory(CAPACITY); + } + + private WritableMemory getMemory(int capacity) + { + final ByteBuffer aBuffer = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN); + return SafeWritableMemory.wrap(aBuffer); + } +}