Add mechanism for 'safe' memory reads for complex types (#13361)

* we can read where we want to
we can leave your bounds behind
'cause if the memory is not there
we really don't care
and we'll crash this process of mine
This commit is contained in:
Clint Wylie 2022-11-23 00:25:22 -08:00 committed by GitHub
parent c26b18c953
commit f524c68f08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 2796 additions and 11 deletions

View File

@ -28,6 +28,7 @@ import org.apache.druid.segment.GenericColumnSerializer;
import org.apache.druid.segment.column.ColumnBuilder; import org.apache.druid.segment.column.ColumnBuilder;
import org.apache.druid.segment.data.GenericIndexed; import org.apache.druid.segment.data.GenericIndexed;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import org.apache.druid.segment.serde.ComplexColumnPartSupplier; import org.apache.druid.segment.serde.ComplexColumnPartSupplier;
import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.apache.druid.segment.serde.ComplexMetricSerde; import org.apache.druid.segment.serde.ComplexMetricSerde;
@ -70,7 +71,7 @@ public class HllSketchMergeComplexMetricSerde extends ComplexMetricSerde
if (object == null) { if (object == null) {
return null; return null;
} }
return deserializeSketch(object); return deserializeSketchSafe(object);
} }
}; };
} }
@ -98,6 +99,18 @@ public class HllSketchMergeComplexMetricSerde extends ComplexMetricSerde
throw new IAE("Object is not of a type that can be deserialized to an HllSketch:" + object.getClass().getName()); throw new IAE("Object is not of a type that can be deserialized to an HllSketch:" + object.getClass().getName());
} }
static HllSketch deserializeSketchSafe(final Object object)
{
if (object instanceof String) {
return HllSketch.wrap(SafeWritableMemory.wrap(StringUtils.decodeBase64(((String) object).getBytes(StandardCharsets.UTF_8))));
} else if (object instanceof byte[]) {
return HllSketch.wrap(SafeWritableMemory.wrap((byte[]) object));
} else if (object instanceof HllSketch) {
return (HllSketch) object;
}
throw new IAE("Object is not of a type that can be deserialized to an HllSketch:" + object.getClass().getName());
}
// support large columns // support large columns
@Override @Override
public GenericColumnSerializer getSerializer(final SegmentWriteOutMedium segmentWriteOutMedium, final String column) public GenericColumnSerializer getSerializer(final SegmentWriteOutMedium segmentWriteOutMedium, final String column)

View File

@ -22,7 +22,9 @@ package org.apache.druid.query.aggregation.datasketches.hll;
import org.apache.datasketches.hll.HllSketch; import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
@ -55,4 +57,12 @@ public class HllSketchObjectStrategy implements ObjectStrategy<HllSketch>
return sketch.toCompactByteArray(); return sketch.toCompactByteArray();
} }
@Nullable
@Override
public HllSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
return HllSketch.wrap(
SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
}
} }

View File

@ -91,7 +91,7 @@ public class KllDoublesSketchComplexMetricSerde extends ComplexMetricSerde
if (object == null || object instanceof KllDoublesSketch || object instanceof Memory) { if (object == null || object instanceof KllDoublesSketch || object instanceof Memory) {
return object; return object;
} }
return KllDoublesSketchOperations.deserialize(object); return KllDoublesSketchOperations.deserializeSafe(object);
} }
}; };
} }

View File

@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.kll.KllDoublesSketch; import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
@ -60,4 +62,15 @@ public class KllDoublesSketchObjectStrategy implements ObjectStrategy<KllDoubles
return sketch.toByteArray(); return sketch.toByteArray();
} }
@Nullable
@Override
public KllDoublesSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
if (numBytes == 0) {
return KllDoublesSketchOperations.EMPTY_SKETCH;
}
return KllDoublesSketch.wrap(
SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
}
} }

View File

@ -23,6 +23,7 @@ import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -46,6 +47,16 @@ public class KllDoublesSketchOperations
); );
} }
public static KllDoublesSketch deserializeSafe(final Object serializedSketch)
{
if (serializedSketch instanceof String) {
return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
} else if (serializedSketch instanceof byte[]) {
return deserializeFromByteArraySafe((byte[]) serializedSketch);
}
return deserialize(serializedSketch);
}
public static KllDoublesSketch deserializeFromBase64EncodedString(final String str) public static KllDoublesSketch deserializeFromBase64EncodedString(final String str)
{ {
return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8))); return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@ -56,4 +67,14 @@ public class KllDoublesSketchOperations
return KllDoublesSketch.wrap(Memory.wrap(data)); return KllDoublesSketch.wrap(Memory.wrap(data));
} }
public static KllDoublesSketch deserializeFromBase64EncodedStringSafe(final String str)
{
return deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
}
public static KllDoublesSketch deserializeFromByteArraySafe(final byte[] data)
{
return KllDoublesSketch.wrap(SafeWritableMemory.wrap(data));
}
} }

View File

@ -91,7 +91,7 @@ public class KllFloatsSketchComplexMetricSerde extends ComplexMetricSerde
if (object == null || object instanceof KllFloatsSketch || object instanceof Memory) { if (object == null || object instanceof KllFloatsSketch || object instanceof Memory) {
return object; return object;
} }
return KllFloatsSketchOperations.deserialize(object); return KllFloatsSketchOperations.deserializeSafe(object);
} }
}; };
} }

View File

@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.kll.KllFloatsSketch; import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
@ -60,4 +62,15 @@ public class KllFloatsSketchObjectStrategy implements ObjectStrategy<KllFloatsSk
return sketch.toByteArray(); return sketch.toByteArray();
} }
@Nullable
@Override
public KllFloatsSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
if (numBytes == 0) {
return KllFloatsSketchOperations.EMPTY_SKETCH;
}
return KllFloatsSketch.wrap(
SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
}
} }

View File

@ -23,6 +23,7 @@ import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -46,6 +47,16 @@ public class KllFloatsSketchOperations
); );
} }
public static KllFloatsSketch deserializeSafe(final Object serializedSketch)
{
if (serializedSketch instanceof String) {
return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
} else if (serializedSketch instanceof byte[]) {
return deserializeFromByteArraySafe((byte[]) serializedSketch);
}
return deserialize(serializedSketch);
}
public static KllFloatsSketch deserializeFromBase64EncodedString(final String str) public static KllFloatsSketch deserializeFromBase64EncodedString(final String str)
{ {
return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8))); return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@ -56,4 +67,14 @@ public class KllFloatsSketchOperations
return KllFloatsSketch.wrap(Memory.wrap(data)); return KllFloatsSketch.wrap(Memory.wrap(data));
} }
public static KllFloatsSketch deserializeFromBase64EncodedStringSafe(final String str)
{
return deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
}
public static KllFloatsSketch deserializeFromByteArraySafe(final byte[] data)
{
return KllFloatsSketch.wrap(SafeWritableMemory.wrap(data));
}
} }

View File

@ -92,7 +92,7 @@ public class DoublesSketchComplexMetricSerde extends ComplexMetricSerde
if (object == null || object instanceof DoublesSketch || object instanceof Memory) { if (object == null || object instanceof DoublesSketch || object instanceof Memory) {
return object; return object;
} }
return DoublesSketchOperations.deserialize(object); return DoublesSketchOperations.deserializeSafe(object);
} }
}; };
} }

View File

@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.quantiles.DoublesSketch; import org.apache.datasketches.quantiles.DoublesSketch;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
@ -60,4 +62,15 @@ public class DoublesSketchObjectStrategy implements ObjectStrategy<DoublesSketch
return sketch.toByteArray(true); return sketch.toByteArray(true);
} }
@Nullable
@Override
public DoublesSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
if (numBytes == 0) {
return DoublesSketchOperations.EMPTY_SKETCH;
}
return DoublesSketch.wrap(
SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
}
} }

View File

@ -23,6 +23,7 @@ import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.quantiles.DoublesSketch; import org.apache.datasketches.quantiles.DoublesSketch;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -46,6 +47,16 @@ public class DoublesSketchOperations
); );
} }
public static DoublesSketch deserializeSafe(final Object serializedSketch)
{
if (serializedSketch instanceof String) {
return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
} else if (serializedSketch instanceof byte[]) {
return deserializeFromByteArraySafe((byte[]) serializedSketch);
}
return deserialize(serializedSketch);
}
public static DoublesSketch deserializeFromBase64EncodedString(final String str) public static DoublesSketch deserializeFromBase64EncodedString(final String str)
{ {
return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8))); return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@ -56,4 +67,13 @@ public class DoublesSketchOperations
return DoublesSketch.wrap(Memory.wrap(data)); return DoublesSketch.wrap(Memory.wrap(data));
} }
public static DoublesSketch deserializeFromBase64EncodedStringSafe(final String str)
{
return deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
}
public static DoublesSketch deserializeFromByteArraySafe(final byte[] data)
{
return DoublesSketch.wrap(SafeWritableMemory.wrap(data));
}
} }

View File

@ -51,7 +51,7 @@ public class SketchConstantPostAggregator implements PostAggregator
Preconditions.checkArgument(value != null && !value.isEmpty(), Preconditions.checkArgument(value != null && !value.isEmpty(),
"Constant value cannot be null or empty, expecting base64 encoded sketch string"); "Constant value cannot be null or empty, expecting base64 encoded sketch string");
this.value = value; this.value = value;
this.sketchValue = SketchHolder.deserialize(value); this.sketchValue = SketchHolder.deserializeSafe(value);
} }
@Override @Override

View File

@ -34,6 +34,7 @@ import org.apache.datasketches.theta.Union;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -224,6 +225,17 @@ public class SketchHolder
); );
} }
public static SketchHolder deserializeSafe(Object serializedSketch)
{
if (serializedSketch instanceof String) {
return SketchHolder.of(deserializeFromBase64EncodedStringSafe((String) serializedSketch));
} else if (serializedSketch instanceof byte[]) {
return SketchHolder.of(deserializeFromByteArraySafe((byte[]) serializedSketch));
}
return deserialize(serializedSketch);
}
private static Sketch deserializeFromBase64EncodedString(String str) private static Sketch deserializeFromBase64EncodedString(String str)
{ {
return deserializeFromByteArray(StringUtils.decodeBase64(StringUtils.toUtf8(str))); return deserializeFromByteArray(StringUtils.decodeBase64(StringUtils.toUtf8(str)));
@ -234,6 +246,16 @@ public class SketchHolder
return deserializeFromMemory(Memory.wrap(data)); return deserializeFromMemory(Memory.wrap(data));
} }
private static Sketch deserializeFromBase64EncodedStringSafe(String str)
{
return deserializeFromByteArraySafe(StringUtils.decodeBase64(StringUtils.toUtf8(str)));
}
private static Sketch deserializeFromByteArraySafe(byte[] data)
{
return deserializeFromMemory(SafeWritableMemory.wrap(data));
}
private static Sketch deserializeFromMemory(Memory mem) private static Sketch deserializeFromMemory(Memory mem)
{ {
if (Sketch.getSerializationVersion(mem) < 3) { if (Sketch.getSerializationVersion(mem) < 3) {

View File

@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.Sketch; import org.apache.datasketches.theta.Sketch;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -66,4 +67,17 @@ public class SketchHolderObjectStrategy implements ObjectStrategy<SketchHolder>
return ByteArrays.EMPTY_ARRAY; return ByteArrays.EMPTY_ARRAY;
} }
} }
@Nullable
@Override
public SketchHolder fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
if (numBytes == 0) {
return SketchHolder.EMPTY;
}
return SketchHolder.of(
SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
}
} }

View File

@ -59,7 +59,7 @@ public class SketchMergeComplexMetricSerde extends ComplexMetricSerde
public SketchHolder extractValue(InputRow inputRow, String metricName) public SketchHolder extractValue(InputRow inputRow, String metricName)
{ {
final Object object = inputRow.getRaw(metricName); final Object object = inputRow.getRaw(metricName);
return object == null ? null : SketchHolder.deserialize(object); return object == null ? null : SketchHolder.deserializeSafe(object);
} }
}; };
} }

View File

@ -60,7 +60,7 @@ public class ArrayOfDoublesSketchMergeComplexMetricSerde extends ComplexMetricSe
if (object == null || object instanceof ArrayOfDoublesSketch) { if (object == null || object instanceof ArrayOfDoublesSketch) {
return object; return object;
} }
return ArrayOfDoublesSketchOperations.deserialize(object); return ArrayOfDoublesSketchOperations.deserializeSafe(object);
} }
}; };
} }

View File

@ -23,6 +23,7 @@ import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketch; import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketch;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketches; import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketches;
import org.apache.druid.segment.data.ObjectStrategy; import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -48,7 +49,9 @@ public class ArrayOfDoublesSketchObjectStrategy implements ObjectStrategy<ArrayO
@Override @Override
public ArrayOfDoublesSketch fromByteBuffer(final ByteBuffer buffer, final int numBytes) public ArrayOfDoublesSketch fromByteBuffer(final ByteBuffer buffer, final int numBytes)
{ {
return ArrayOfDoublesSketches.wrapSketch(Memory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)); return ArrayOfDoublesSketches.wrapSketch(
Memory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
} }
@Override @Override
@ -61,4 +64,12 @@ public class ArrayOfDoublesSketchObjectStrategy implements ObjectStrategy<ArrayO
return sketch.toByteArray(); return sketch.toByteArray();
} }
@Nullable
@Override
public ArrayOfDoublesSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
return ArrayOfDoublesSketches.wrapSketch(
SafeWritableMemory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
);
}
} }

View File

@ -30,6 +30,7 @@ import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUnion;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -115,6 +116,17 @@ public class ArrayOfDoublesSketchOperations
throw new ISE("Object is not of a type that can deserialize to sketch: %s", serializedSketch.getClass()); throw new ISE("Object is not of a type that can deserialize to sketch: %s", serializedSketch.getClass());
} }
public static ArrayOfDoublesSketch deserializeSafe(final Object serializedSketch)
{
if (serializedSketch instanceof String) {
return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
} else if (serializedSketch instanceof byte[]) {
return deserializeFromByteArraySafe((byte[]) serializedSketch);
}
return deserialize(serializedSketch);
}
public static ArrayOfDoublesSketch deserializeFromBase64EncodedString(final String str) public static ArrayOfDoublesSketch deserializeFromBase64EncodedString(final String str)
{ {
return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8))); return deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@ -122,8 +134,16 @@ public class ArrayOfDoublesSketchOperations
public static ArrayOfDoublesSketch deserializeFromByteArray(final byte[] data) public static ArrayOfDoublesSketch deserializeFromByteArray(final byte[] data)
{ {
final Memory mem = Memory.wrap(data); return ArrayOfDoublesSketches.wrapSketch(Memory.wrap(data));
return ArrayOfDoublesSketches.wrapSketch(mem);
} }
public static ArrayOfDoublesSketch deserializeFromBase64EncodedStringSafe(final String str)
{
return deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
}
public static ArrayOfDoublesSketch deserializeFromByteArraySafe(final byte[] data)
{
return ArrayOfDoublesSketches.wrapSketch(SafeWritableMemory.wrap(data));
}
} }

View File

@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.hll;
import org.apache.datasketches.SketchesArgumentException;
import org.apache.datasketches.hll.HllSketch;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class HllSketchObjectStrategyTest
{
@Test
public void testSafeRead()
{
HllSketch sketch = new HllSketch();
sketch.update(new int[]{1, 2, 3});
final byte[] bytes = sketch.toCompactByteArray();
ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
HllSketchObjectStrategy objectStrategy = new HllSketchObjectStrategy();
// valid sketch should not explode when copied, which reads the memory
objectStrategy.fromByteBufferSafe(buf, bytes.length).copy();
// corrupted sketch should fail with a regular java buffer exception
for (int subset = 3; subset < bytes.length - 1; subset++) {
final byte[] garbage2 = new byte[subset];
for (int i = 0; i < garbage2.length; i++) {
garbage2[i] = buf.get(i);
}
final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).copy()
);
}
// non sketch that is too short to contain header should fail with regular java buffer exception
final byte[] garbage = new byte[]{0x01, 0x02};
final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).copy()
);
// non sketch that is long enough to check (this one doesn't actually need 'safe' read)
final byte[] garbageLonger = StringUtils.toUtf8("notasketch");
final ByteBuffer buf4 = ByteBuffer.wrap(garbageLonger).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
SketchesArgumentException.class,
() -> objectStrategy.fromByteBufferSafe(buf4, garbageLonger.length).copy()
);
}
}

View File

@ -23,10 +23,14 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.datasketches.kll.KllDoublesSketch; import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.druid.data.input.MapBasedInputRow; import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class KllDoublesSketchComplexMetricSerdeTest public class KllDoublesSketchComplexMetricSerdeTest
{ {
@Test @Test
@ -92,4 +96,44 @@ public class KllDoublesSketchComplexMetricSerdeTest
Assert.assertEquals(1, sketch.getNumRetained()); Assert.assertEquals(1, sketch.getNumRetained());
Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d); Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d);
} }
@Test
public void testSafeRead()
{
final KllDoublesSketchComplexMetricSerde serde = new KllDoublesSketchComplexMetricSerde();
final ObjectStrategy<KllDoublesSketch> objectStrategy = serde.getObjectStrategy();
KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance();
sketch.update(1.1);
sketch.update(1.2);
final byte[] bytes = sketch.toByteArray();
ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
// valid sketch should not explode when converted to byte array, which reads the memory
objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray();
// corrupted sketch should fail with a regular java buffer exception, not all subsets actually fail with the same
// index out of bounds exceptions, but at least this many do
for (int subset = 3; subset < 24; subset++) {
final byte[] garbage2 = new byte[subset];
for (int i = 0; i < garbage2.length; i++) {
garbage2[i] = buf.get(i);
}
final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).toByteArray()
);
}
// non sketch that is too short to contain header should fail with regular java buffer exception
final byte[] garbage = new byte[]{0x01, 0x02};
final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).toByteArray()
);
}
} }

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.kll;
import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
public class KllDoublesSketchOperationsTest
{
@Test
public void testDeserializeSafe()
{
KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance();
sketch.update(1.1);
sketch.update(1.2);
final byte[] bytes = sketch.toByteArray();
final String base64 = StringUtils.encodeBase64String(bytes);
Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(sketch).toByteArray());
Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(bytes).toByteArray());
Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(base64).toByteArray());
final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 20);
Assert.assertThrows(IndexOutOfBoundsException.class, () -> KllDoublesSketchOperations.deserializeSafe(trunacted));
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> KllDoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(trunacted))
);
}
}

View File

@ -23,10 +23,14 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.datasketches.kll.KllFloatsSketch; import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.druid.data.input.MapBasedInputRow; import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class KllFloatsSketchComplexMetricSerdeTest public class KllFloatsSketchComplexMetricSerdeTest
{ {
@Test @Test
@ -92,4 +96,44 @@ public class KllFloatsSketchComplexMetricSerdeTest
Assert.assertEquals(1, sketch.getNumRetained()); Assert.assertEquals(1, sketch.getNumRetained());
Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d); Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d);
} }
@Test
public void testSafeRead()
{
final KllFloatsSketchComplexMetricSerde serde = new KllFloatsSketchComplexMetricSerde();
final ObjectStrategy<KllFloatsSketch> objectStrategy = serde.getObjectStrategy();
KllFloatsSketch sketch = KllFloatsSketch.newHeapInstance();
sketch.update(1.1f);
sketch.update(1.2f);
final byte[] bytes = sketch.toByteArray();
ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
// valid sketch should not explode when converted to byte array, which reads the memory
objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray();
// corrupted sketch should fail with a regular java buffer exception, not all subsets actually fail with the same
// index out of bounds exceptions, but at least this many do
for (int subset = 3; subset < 24; subset++) {
final byte[] garbage2 = new byte[subset];
for (int i = 0; i < garbage2.length; i++) {
garbage2[i] = buf.get(i);
}
final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).toByteArray()
);
}
// non sketch that is too short to contain header should fail with regular java buffer exception
final byte[] garbage = new byte[]{0x01, 0x02};
final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).toByteArray()
);
}
} }

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.kll;
import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
public class KllFloatsSketchOperationsTest
{
@Test
public void testDeserializeSafe()
{
KllFloatsSketch sketch = KllFloatsSketch.newHeapInstance();
sketch.update(1.1f);
sketch.update(1.2f);
final byte[] bytes = sketch.toByteArray();
final String base64 = StringUtils.encodeBase64String(bytes);
Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(sketch).toByteArray());
Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(bytes).toByteArray());
Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(base64).toByteArray());
final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 20);
Assert.assertThrows(IndexOutOfBoundsException.class, () -> KllFloatsSketchOperations.deserializeSafe(trunacted));
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> KllFloatsSketchOperations.deserializeSafe(StringUtils.encodeBase64String(trunacted))
);
}
}

View File

@ -22,11 +22,16 @@ package org.apache.druid.query.aggregation.datasketches.quantiles;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.datasketches.quantiles.DoublesSketch; import org.apache.datasketches.quantiles.DoublesSketch;
import org.apache.datasketches.quantiles.DoublesUnion;
import org.apache.druid.data.input.MapBasedInputRow; import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.serde.ComplexMetricExtractor; import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class DoublesSketchComplexMetricSerdeTest public class DoublesSketchComplexMetricSerdeTest
{ {
@Test @Test
@ -92,4 +97,42 @@ public class DoublesSketchComplexMetricSerdeTest
Assert.assertEquals(1, sketch.getRetainedItems()); Assert.assertEquals(1, sketch.getRetainedItems());
Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d); Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d);
} }
@Test
public void testSafeRead()
{
final DoublesSketchComplexMetricSerde serde = new DoublesSketchComplexMetricSerde();
DoublesUnion union = DoublesUnion.builder().setMaxK(1024).build();
union.update(1.1);
final byte[] bytes = union.toByteArray();
ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
ObjectStrategy<DoublesSketch> objectStrategy = serde.getObjectStrategy();
// valid sketch should not explode when copied, which reads the memory
objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray(true);
// corrupted sketch should fail with a regular java buffer exception
for (int subset = 3; subset < 15; subset++) {
final byte[] garbage2 = new byte[subset];
for (int i = 0; i < garbage2.length; i++) {
garbage2[i] = buf.get(i);
}
final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
"i " + subset,
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).toByteArray(true)
);
}
// non sketch that is too short to contain header should fail with regular java buffer exception
final byte[] garbage = new byte[]{0x01, 0x02};
final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).toByteArray(true)
);
}
} }

View File

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.quantiles;
import org.apache.datasketches.quantiles.DoublesUnion;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
public class DoublesSketchOperationsTest
{
@Test
public void testDeserializeSafe()
{
DoublesUnion union = DoublesUnion.builder().setMaxK(1024).build();
union.update(1.1);
final byte[] bytes = union.getResult().toByteArray();
final String base64 = StringUtils.encodeBase64String(bytes);
Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(union.getResult()).toByteArray());
Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(bytes).toByteArray());
Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(base64).toByteArray());
final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 4);
Assert.assertThrows(IndexOutOfBoundsException.class, () -> DoublesSketchOperations.deserializeSafe(trunacted));
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> DoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64(trunacted))
);
}
}

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.theta;
import org.apache.datasketches.Family;
import org.apache.datasketches.SketchesArgumentException;
import org.apache.datasketches.theta.SetOperation;
import org.apache.datasketches.theta.Union;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class SketchHolderObjectStrategyTest
{
@Test
public void testSafeRead()
{
SketchHolderObjectStrategy objectStrategy = new SketchHolderObjectStrategy();
Union union = (Union) SetOperation.builder().setNominalEntries(1024).build(Family.UNION);
union.update(1234L);
final byte[] bytes = union.getResult().toByteArray();
ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
// valid sketch should not explode when copied, which reads the memory
objectStrategy.fromByteBufferSafe(buf, bytes.length).getSketch().compact().getCompactBytes();
// corrupted sketch should fail with a regular java buffer exception
for (int subset = 3; subset < bytes.length - 1; subset++) {
final byte[] garbage2 = new byte[subset];
for (int i = 0; i < garbage2.length; i++) {
garbage2[i] = buf.get(i);
}
final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).getSketch().compact().getCompactBytes()
);
}
// non sketch that is too short to contain header should fail with regular java buffer exception
final byte[] garbage = new byte[]{0x01, 0x02};
final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).getSketch().compact().getCompactBytes()
);
// non sketch that is long enough to check (this one doesn't actually need 'safe' read)
final byte[] garbageLonger = StringUtils.toUtf8("notasketch");
final ByteBuffer buf4 = ByteBuffer.wrap(garbageLonger).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
SketchesArgumentException.class,
() -> objectStrategy.fromByteBufferSafe(buf4, garbageLonger.length).getSketch().compact().getCompactBytes()
);
}
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.theta;
import org.apache.datasketches.Family;
import org.apache.datasketches.theta.SetOperation;
import org.apache.datasketches.theta.Union;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
public class SketchHolderTest
{
@Test
public void testDeserializeSafe()
{
Union union = (Union) SetOperation.builder().setNominalEntries(1024).build(Family.UNION);
union.update(1234L);
final byte[] bytes = union.getResult().toByteArray();
final String base64 = StringUtils.encodeBase64String(bytes);
Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(union.getResult()).getSketch().toByteArray());
Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(bytes).getSketch().toByteArray());
Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(base64).getSketch().toByteArray());
final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 10);
Assert.assertThrows(IndexOutOfBoundsException.class, () -> SketchHolder.deserializeSafe(trunacted));
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> SketchHolder.deserializeSafe(StringUtils.encodeBase64String(trunacted))
);
}
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.tuple;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketch;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketchBuilder;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class ArrayOfDoublesSketchObjectStrategyTest
{
@Test
public void testSafeRead()
{
ArrayOfDoublesSketchObjectStrategy objectStrategy = new ArrayOfDoublesSketchObjectStrategy();
ArrayOfDoublesUpdatableSketch sketch = new ArrayOfDoublesUpdatableSketchBuilder().setNominalEntries(1024)
.setNumberOfValues(4)
.build();
sketch.update(1L, new double[]{1.0, 2.0, 3.0, 4.0});
final byte[] bytes = sketch.compact().toByteArray();
ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
// valid sketch should not explode when copied, which reads the memory
objectStrategy.fromByteBufferSafe(buf, bytes.length).compact().toByteArray();
// corrupted sketch should fail with a regular java buffer exception
for (int subset = 3; subset < bytes.length - 1; subset++) {
final byte[] garbage2 = new byte[subset];
for (int i = 0; i < garbage2.length; i++) {
garbage2[i] = buf.get(i);
}
final ByteBuffer buf2 = ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).compact().toByteArray()
);
}
// non sketch that is too short to contain header should fail with regular java buffer exception
final byte[] garbage = new byte[]{0x01, 0x02};
final ByteBuffer buf3 = ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).compact().toByteArray()
);
}
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.datasketches.tuple;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketch;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketchBuilder;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
public class ArrayOfDoublesSketchOperationsTest
{
@Test
public void testDeserializeSafe()
{
ArrayOfDoublesSketchObjectStrategy objectStrategy = new ArrayOfDoublesSketchObjectStrategy();
ArrayOfDoublesUpdatableSketch sketch = new ArrayOfDoublesUpdatableSketchBuilder().setNominalEntries(1024)
.setNumberOfValues(4)
.build();
sketch.update(1L, new double[]{1.0, 2.0, 3.0, 4.0});
final byte[] bytes = sketch.toByteArray();
final String base64 = StringUtils.encodeBase64String(bytes);
Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(sketch).toByteArray());
Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(bytes).toByteArray());
Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(base64).toByteArray());
final byte[] trunacted = Arrays.copyOfRange(bytes, 0, 10);
Assert.assertThrows(IndexOutOfBoundsException.class, () -> ArrayOfDoublesSketchOperations.deserializeSafe(trunacted));
Assert.assertThrows(
IndexOutOfBoundsException.class,
() -> ArrayOfDoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(trunacted))
);
}
}

View File

@ -90,6 +90,6 @@ public class ObjectStrategyComplexTypeStrategy<T> implements TypeStrategy<T>
@Override @Override
public T fromBytes(byte[] value) public T fromBytes(byte[] value)
{ {
return objectStrategy.fromByteBuffer(ByteBuffer.wrap(value), value.length); return objectStrategy.fromByteBufferSafe(ByteBuffer.wrap(value), value.length);
} }
} }

View File

@ -79,4 +79,31 @@ public interface ObjectStrategy<T> extends Comparator<T>
out.write(bytes); out.write(bytes);
} }
} }
/**
* Convert values from their underlying byte representation, when the underlying bytes might be corrupted or
* maliciously constructed
*
* Implementations of this method <i>absolutely must never</i> perform any sun.misc.Unsafe based memory read or write
* operations from instructions contained in the data read from this buffer without first validating the data. If the
* data cannot be validated, all read and write operations from instructions in this data must be done directly with
* the {@link ByteBuffer} methods, or using {@link SafeWritableMemory} if
* {@link org.apache.datasketches.memory.Memory} is employed to materialize the value.
*
* Implementations of this method <i>may</i> change the given buffer's mark, or limit, and position.
*
* Implementations of this method <i>may not</i> store the given buffer in a field of the "deserialized" object,
* need to use {@link ByteBuffer#slice()}, {@link ByteBuffer#asReadOnlyBuffer()} or {@link ByteBuffer#duplicate()} in
* this case.
*
*
* @param buffer buffer to read value from
* @param numBytes number of bytes used to store the value, starting at buffer.position()
* @return an object created from the given byte buffer representation
*/
@Nullable
default T fromByteBufferSafe(ByteBuffer buffer, int numBytes)
{
return fromByteBuffer(buffer, numBytes);
}
} }

View File

@ -0,0 +1,450 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.data;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import org.apache.datasketches.memory.BaseState;
import org.apache.datasketches.memory.MemoryRequestServer;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.memory.internal.BaseStateImpl;
import org.apache.datasketches.memory.internal.UnsafeUtil;
import org.apache.datasketches.memory.internal.XxHash64;
import org.apache.druid.java.util.common.StringUtils;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/**
* Base class for making a regular {@link ByteBuffer} look like a {@link org.apache.datasketches.memory.Memory} or
* {@link org.apache.datasketches.memory.Buffer}. All methods delegate directly to the {@link ByteBuffer} rather
* than using 'unsafe' reads.
*
* @see SafeWritableMemory
* @see SafeWritableBuffer
*/
@SuppressWarnings("unused")
public abstract class SafeWritableBase implements BaseState
{
static final MemoryRequestServer SAFE_HEAP_REQUEST_SERVER = new HeapByteBufferMemoryRequestServer();
final ByteBuffer buffer;
public SafeWritableBase(ByteBuffer buffer)
{
this.buffer = buffer;
}
public MemoryRequestServer getMemoryRequestServer()
{
return SAFE_HEAP_REQUEST_SERVER;
}
public boolean getBoolean(long offsetBytes)
{
return getByte(Ints.checkedCast(offsetBytes)) != 0;
}
public byte getByte(long offsetBytes)
{
return buffer.get(Ints.checkedCast(offsetBytes));
}
public char getChar(long offsetBytes)
{
return buffer.getChar(Ints.checkedCast(offsetBytes));
}
public double getDouble(long offsetBytes)
{
return buffer.getDouble(Ints.checkedCast(offsetBytes));
}
public float getFloat(long offsetBytes)
{
return buffer.getFloat(Ints.checkedCast(offsetBytes));
}
public int getInt(long offsetBytes)
{
return buffer.getInt(Ints.checkedCast(offsetBytes));
}
public long getLong(long offsetBytes)
{
return buffer.getLong(Ints.checkedCast(offsetBytes));
}
public short getShort(long offsetBytes)
{
return buffer.getShort(Ints.checkedCast(offsetBytes));
}
public void putBoolean(long offsetBytes, boolean value)
{
buffer.put(Ints.checkedCast(offsetBytes), (byte) (value ? 1 : 0));
}
public void putByte(long offsetBytes, byte value)
{
buffer.put(Ints.checkedCast(offsetBytes), value);
}
public void putChar(long offsetBytes, char value)
{
buffer.putChar(Ints.checkedCast(offsetBytes), value);
}
public void putDouble(long offsetBytes, double value)
{
buffer.putDouble(Ints.checkedCast(offsetBytes), value);
}
public void putFloat(long offsetBytes, float value)
{
buffer.putFloat(Ints.checkedCast(offsetBytes), value);
}
public void putInt(long offsetBytes, int value)
{
buffer.putInt(Ints.checkedCast(offsetBytes), value);
}
public void putLong(long offsetBytes, long value)
{
buffer.putLong(Ints.checkedCast(offsetBytes), value);
}
public void putShort(long offsetBytes, short value)
{
buffer.putShort(Ints.checkedCast(offsetBytes), value);
}
@Override
public ByteOrder getTypeByteOrder()
{
return buffer.order();
}
@Override
public boolean isByteOrderCompatible(ByteOrder byteOrder)
{
return buffer.order().equals(byteOrder);
}
@Override
public ByteBuffer getByteBuffer()
{
return buffer;
}
@Override
public long getCapacity()
{
return buffer.capacity();
}
@Override
public long getCumulativeOffset()
{
return 0;
}
@Override
public long getCumulativeOffset(long offsetBytes)
{
return offsetBytes;
}
@Override
public long getRegionOffset()
{
return 0;
}
@Override
public long getRegionOffset(long offsetBytes)
{
return offsetBytes;
}
@Override
public boolean hasArray()
{
return false;
}
@Override
public long xxHash64(long offsetBytes, long lengthBytes, long seed)
{
return hash(buffer, offsetBytes, lengthBytes, seed);
}
@Override
public long xxHash64(long in, long seed)
{
return XxHash64.hash(in, seed);
}
@Override
public boolean hasByteBuffer()
{
return true;
}
@Override
public boolean isDirect()
{
return false;
}
@Override
public boolean isReadOnly()
{
return false;
}
@Override
public boolean isSameResource(Object that)
{
return this.equals(that);
}
@Override
public boolean isValid()
{
return true;
}
@Override
public void checkValidAndBounds(long offsetBytes, long lengthBytes)
{
Preconditions.checkArgument(
Ints.checkedCast(offsetBytes) < buffer.limit(),
"start offset %s is greater than buffer limit %s",
offsetBytes,
buffer.limit()
);
Preconditions.checkArgument(
Ints.checkedCast(offsetBytes + lengthBytes) < buffer.limit(),
"end offset %s is greater than buffer limit %s",
offsetBytes + lengthBytes,
buffer.limit()
);
}
/**
* Adapted from {@link BaseStateImpl#toHexString(String, long, int)}
*/
@Override
public String toHexString(String header, long offsetBytes, int lengthBytes)
{
final String klass = this.getClass().getSimpleName();
final String s1 = StringUtils.format("(..., %d, %d)", offsetBytes, lengthBytes);
final long hcode = hashCode() & 0XFFFFFFFFL;
final String call = ".toHexString" + s1 + ", hashCode: " + hcode;
String sb = "### " + klass + " SUMMARY ###" + UnsafeUtil.LS
+ "Header Comment : " + header + UnsafeUtil.LS
+ "Call Parameters : " + call;
return toHex(this, sb, offsetBytes, lengthBytes);
}
/**
* Adapted from {@link BaseStateImpl#toHex(BaseStateImpl, String, long, int)}
*/
static String toHex(
final SafeWritableBase state,
final String preamble,
final long offsetBytes,
final int lengthBytes
)
{
final String lineSeparator = UnsafeUtil.LS;
final long capacity = state.getCapacity();
UnsafeUtil.checkBounds(offsetBytes, lengthBytes, capacity);
final StringBuilder sb = new StringBuilder();
final String uObjStr;
final long uObjHeader;
uObjStr = "null";
uObjHeader = 0;
final ByteBuffer bb = state.getByteBuffer();
final String bbStr = bb == null ? "null"
: bb.getClass().getSimpleName() + ", " + (bb.hashCode() & 0XFFFFFFFFL);
final MemoryRequestServer memReqSvr = state.getMemoryRequestServer();
final String memReqStr = memReqSvr != null
? memReqSvr.getClass().getSimpleName() + ", " + (memReqSvr.hashCode() & 0XFFFFFFFFL)
: "null";
final long cumBaseOffset = state.getCumulativeOffset();
sb.append(preamble).append(lineSeparator);
sb.append("UnsafeObj, hashCode : ").append(uObjStr).append(lineSeparator);
sb.append("UnsafeObjHeader : ").append(uObjHeader).append(lineSeparator);
sb.append("ByteBuf, hashCode : ").append(bbStr).append(lineSeparator);
sb.append("RegionOffset : ").append(state.getRegionOffset()).append(lineSeparator);
sb.append("Capacity : ").append(capacity).append(lineSeparator);
sb.append("CumBaseOffset : ").append(cumBaseOffset).append(lineSeparator);
sb.append("MemReq, hashCode : ").append(memReqStr).append(lineSeparator);
sb.append("Valid : ").append(state.isValid()).append(lineSeparator);
sb.append("Read Only : ").append(state.isReadOnly()).append(lineSeparator);
sb.append("Type Byte Order : ").append(state.getTypeByteOrder()).append(lineSeparator);
sb.append("Native Byte Order : ").append(ByteOrder.nativeOrder()).append(lineSeparator);
sb.append("JDK Runtime Version : ").append(UnsafeUtil.JDK).append(lineSeparator);
//Data detail
sb.append("Data, littleEndian : 0 1 2 3 4 5 6 7");
for (long i = 0; i < lengthBytes; i++) {
final int b = state.getByte(cumBaseOffset + offsetBytes + i) & 0XFF;
if (i % 8 == 0) { //row header
sb.append(StringUtils.format("%n%20s: ", offsetBytes + i));
}
sb.append(StringUtils.format("%02x ", b));
}
sb.append(lineSeparator);
return sb.toString();
}
// copied from datasketches-memory XxHash64.java
private static final long P1 = -7046029288634856825L;
private static final long P2 = -4417276706812531889L;
private static final long P3 = 1609587929392839161L;
private static final long P4 = -8796714831421723037L;
private static final long P5 = 2870177450012600261L;
/**
* Adapted from {@link XxHash64#hash(Object, long, long, long)} to work with {@link ByteBuffer}
*/
static long hash(ByteBuffer memory, long cumOffsetBytes, final long lengthBytes, final long seed)
{
long hash;
long remaining = lengthBytes;
int offset = Ints.checkedCast(cumOffsetBytes);
if (remaining >= 32) {
long v1 = seed + P1 + P2;
long v2 = seed + P2;
long v3 = seed;
long v4 = seed - P1;
do {
v1 += memory.getLong(offset) * P2;
v1 = Long.rotateLeft(v1, 31);
v1 *= P1;
v2 += memory.getLong(offset + 8) * P2;
v2 = Long.rotateLeft(v2, 31);
v2 *= P1;
v3 += memory.getLong(offset + 16) * P2;
v3 = Long.rotateLeft(v3, 31);
v3 *= P1;
v4 += memory.getLong(offset + 24) * P2;
v4 = Long.rotateLeft(v4, 31);
v4 *= P1;
offset += 32;
remaining -= 32;
} while (remaining >= 32);
hash = Long.rotateLeft(v1, 1)
+ Long.rotateLeft(v2, 7)
+ Long.rotateLeft(v3, 12)
+ Long.rotateLeft(v4, 18);
v1 *= P2;
v1 = Long.rotateLeft(v1, 31);
v1 *= P1;
hash ^= v1;
hash = (hash * P1) + P4;
v2 *= P2;
v2 = Long.rotateLeft(v2, 31);
v2 *= P1;
hash ^= v2;
hash = (hash * P1) + P4;
v3 *= P2;
v3 = Long.rotateLeft(v3, 31);
v3 *= P1;
hash ^= v3;
hash = (hash * P1) + P4;
v4 *= P2;
v4 = Long.rotateLeft(v4, 31);
v4 *= P1;
hash ^= v4;
hash = (hash * P1) + P4;
} else { //end remaining >= 32
hash = seed + P5;
}
hash += lengthBytes;
while (remaining >= 8) {
long k1 = memory.getLong(offset);
k1 *= P2;
k1 = Long.rotateLeft(k1, 31);
k1 *= P1;
hash ^= k1;
hash = (Long.rotateLeft(hash, 27) * P1) + P4;
offset += 8;
remaining -= 8;
}
if (remaining >= 4) { //treat as unsigned ints
hash ^= (memory.getInt(offset) & 0XFFFF_FFFFL) * P1;
hash = (Long.rotateLeft(hash, 23) * P2) + P3;
offset += 4;
remaining -= 4;
}
while (remaining != 0) { //treat as unsigned bytes
hash ^= (memory.get(offset) & 0XFFL) * P5;
hash = Long.rotateLeft(hash, 11) * P1;
--remaining;
++offset;
}
hash ^= hash >>> 33;
hash *= P2;
hash ^= hash >>> 29;
hash *= P3;
hash ^= hash >>> 32;
return hash;
}
private static class HeapByteBufferMemoryRequestServer implements MemoryRequestServer
{
@Override
public WritableMemory request(WritableMemory currentWritableMemory, long capacityBytes)
{
ByteBuffer newBuffer = ByteBuffer.allocate(Ints.checkedCast(capacityBytes));
newBuffer.order(currentWritableMemory.getTypeByteOrder());
return new SafeWritableMemory(newBuffer);
}
@Override
public void requestClose(WritableMemory memToClose, WritableMemory newMemory)
{
// do nothing
}
}
}

View File

@ -0,0 +1,501 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.data;
import com.google.common.primitives.Ints;
import org.apache.datasketches.memory.BaseBuffer;
import org.apache.datasketches.memory.Buffer;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableBuffer;
import org.apache.datasketches.memory.WritableMemory;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/**
* Safety first! Don't trust something whose contents you locations to read and write stuff to, but need a
* {@link Buffer} or {@link WritableBuffer}? use this!
* <p>
* Delegates everything to an underlying {@link ByteBuffer} so all read and write operations will have bounds checks
* built in rather than using 'unsafe'.
*/
public class SafeWritableBuffer extends SafeWritableBase implements WritableBuffer
{
private int start;
private int end;
public SafeWritableBuffer(ByteBuffer buffer)
{
super(buffer);
this.start = 0;
this.buffer.position(0);
this.end = buffer.capacity();
}
@Override
public WritableBuffer writableDuplicate()
{
return writableDuplicate(buffer.order());
}
@Override
public WritableBuffer writableDuplicate(ByteOrder byteOrder)
{
ByteBuffer dupe = buffer.duplicate();
dupe.order(byteOrder);
WritableBuffer duplicate = new SafeWritableBuffer(dupe);
duplicate.setStartPositionEnd(start, buffer.position(), end);
return duplicate;
}
@Override
public WritableBuffer writableRegion()
{
ByteBuffer dupe = buffer.duplicate().order(buffer.order());
dupe.position(start);
dupe.limit(end);
ByteBuffer remaining = buffer.slice();
remaining.order(dupe.order());
return new SafeWritableBuffer(remaining);
}
@Override
public WritableBuffer writableRegion(long offsetBytes, long capacityBytes, ByteOrder byteOrder)
{
ByteBuffer dupe = buffer.duplicate();
dupe.position(Ints.checkedCast(offsetBytes));
dupe.limit(dupe.position() + Ints.checkedCast(capacityBytes));
return new SafeWritableBuffer(dupe.slice().order(byteOrder));
}
@Override
public WritableMemory asWritableMemory(ByteOrder byteOrder)
{
ByteBuffer dupe = buffer.duplicate();
dupe.order(byteOrder);
return new SafeWritableMemory(dupe);
}
@Override
public void putBoolean(boolean value)
{
buffer.put((byte) (value ? 1 : 0));
}
@Override
public void putBooleanArray(boolean[] srcArray, int srcOffsetBooleans, int lengthBooleans)
{
for (int i = 0; i < lengthBooleans; i++) {
putBoolean(srcArray[srcOffsetBooleans + i]);
}
}
@Override
public void putByte(byte value)
{
buffer.put(value);
}
@Override
public void putByteArray(byte[] srcArray, int srcOffsetBytes, int lengthBytes)
{
buffer.put(srcArray, srcOffsetBytes, lengthBytes);
}
@Override
public void putChar(char value)
{
buffer.putChar(value);
}
@Override
public void putCharArray(char[] srcArray, int srcOffsetChars, int lengthChars)
{
for (int i = 0; i < lengthChars; i++) {
buffer.putChar(srcArray[srcOffsetChars + i]);
}
}
@Override
public void putDouble(double value)
{
buffer.putDouble(value);
}
@Override
public void putDoubleArray(double[] srcArray, int srcOffsetDoubles, int lengthDoubles)
{
for (int i = 0; i < lengthDoubles; i++) {
buffer.putDouble(srcArray[srcOffsetDoubles + i]);
}
}
@Override
public void putFloat(float value)
{
buffer.putFloat(value);
}
@Override
public void putFloatArray(float[] srcArray, int srcOffsetFloats, int lengthFloats)
{
for (int i = 0; i < lengthFloats; i++) {
buffer.putFloat(srcArray[srcOffsetFloats + i]);
}
}
@Override
public void putInt(int value)
{
buffer.putInt(value);
}
@Override
public void putIntArray(int[] srcArray, int srcOffsetInts, int lengthInts)
{
for (int i = 0; i < lengthInts; i++) {
buffer.putInt(srcArray[srcOffsetInts + i]);
}
}
@Override
public void putLong(long value)
{
buffer.putLong(value);
}
@Override
public void putLongArray(long[] srcArray, int srcOffsetLongs, int lengthLongs)
{
for (int i = 0; i < lengthLongs; i++) {
buffer.putLong(srcArray[srcOffsetLongs + i]);
}
}
@Override
public void putShort(short value)
{
buffer.putShort(value);
}
@Override
public void putShortArray(short[] srcArray, int srcOffsetShorts, int lengthShorts)
{
for (int i = 0; i < lengthShorts; i++) {
buffer.putShort(srcArray[srcOffsetShorts + i]);
}
}
@Override
public Object getArray()
{
return null;
}
@Override
public void clear()
{
fill((byte) 0);
}
@Override
public void fill(byte value)
{
while (buffer.hasRemaining() && buffer.position() < end) {
buffer.put(value);
}
}
@Override
public Buffer duplicate()
{
return writableDuplicate();
}
@Override
public Buffer duplicate(ByteOrder byteOrder)
{
return writableDuplicate(byteOrder);
}
@Override
public Buffer region()
{
return writableRegion();
}
@Override
public Buffer region(long offsetBytes, long capacityBytes, ByteOrder byteOrder)
{
return writableRegion(offsetBytes, capacityBytes, byteOrder);
}
@Override
public Memory asMemory(ByteOrder byteOrder)
{
return asWritableMemory(byteOrder);
}
@Override
public boolean getBoolean()
{
return buffer.get() == 0 ? false : true;
}
@Override
public void getBooleanArray(boolean[] dstArray, int dstOffsetBooleans, int lengthBooleans)
{
for (int i = 0; i < lengthBooleans; i++) {
dstArray[dstOffsetBooleans + i] = getBoolean();
}
}
@Override
public byte getByte()
{
return buffer.get();
}
@Override
public void getByteArray(byte[] dstArray, int dstOffsetBytes, int lengthBytes)
{
for (int i = 0; i < lengthBytes; i++) {
dstArray[dstOffsetBytes + i] = buffer.get();
}
}
@Override
public char getChar()
{
return buffer.getChar();
}
@Override
public void getCharArray(char[] dstArray, int dstOffsetChars, int lengthChars)
{
for (int i = 0; i < lengthChars; i++) {
dstArray[dstOffsetChars + i] = buffer.getChar();
}
}
@Override
public double getDouble()
{
return buffer.getDouble();
}
@Override
public void getDoubleArray(double[] dstArray, int dstOffsetDoubles, int lengthDoubles)
{
for (int i = 0; i < lengthDoubles; i++) {
dstArray[dstOffsetDoubles + i] = buffer.getDouble();
}
}
@Override
public float getFloat()
{
return buffer.getFloat();
}
@Override
public void getFloatArray(float[] dstArray, int dstOffsetFloats, int lengthFloats)
{
for (int i = 0; i < lengthFloats; i++) {
dstArray[dstOffsetFloats + i] = buffer.getFloat();
}
}
@Override
public int getInt()
{
return buffer.getInt();
}
@Override
public void getIntArray(int[] dstArray, int dstOffsetInts, int lengthInts)
{
for (int i = 0; i < lengthInts; i++) {
dstArray[dstOffsetInts + i] = buffer.getInt();
}
}
@Override
public long getLong()
{
return buffer.getLong();
}
@Override
public void getLongArray(long[] dstArray, int dstOffsetLongs, int lengthLongs)
{
for (int i = 0; i < lengthLongs; i++) {
dstArray[dstOffsetLongs + i] = buffer.getLong();
}
}
@Override
public short getShort()
{
return buffer.getShort();
}
@Override
public void getShortArray(short[] dstArray, int dstOffsetShorts, int lengthShorts)
{
for (int i = 0; i < lengthShorts; i++) {
dstArray[dstOffsetShorts + i] = buffer.getShort();
}
}
@Override
public int compareTo(
long thisOffsetBytes,
long thisLengthBytes,
Buffer that,
long thatOffsetBytes,
long thatLengthBytes
)
{
final int thisLength = Ints.checkedCast(thisLengthBytes);
final int thatLength = Ints.checkedCast(thatLengthBytes);
final int commonLength = Math.min(thisLength, thatLength);
for (int i = 0; i < commonLength; i++) {
final int cmp = Byte.compare(getByte(thisOffsetBytes + i), that.getByte(thatOffsetBytes + i));
if (cmp != 0) {
return cmp;
}
}
return Integer.compare(thisLength, thatLength);
}
@Override
public BaseBuffer incrementPosition(long increment)
{
buffer.position(buffer.position() + Ints.checkedCast(increment));
return this;
}
@Override
public BaseBuffer incrementAndCheckPosition(long increment)
{
checkInvariants(start, buffer.position() + increment, end, buffer.capacity());
return incrementPosition(increment);
}
@Override
public long getEnd()
{
return end;
}
@Override
public long getPosition()
{
return buffer.position();
}
@Override
public long getStart()
{
return start;
}
@Override
public long getRemaining()
{
return buffer.remaining();
}
@Override
public boolean hasRemaining()
{
return buffer.hasRemaining();
}
@Override
public BaseBuffer resetPosition()
{
buffer.position(start);
return this;
}
@Override
public BaseBuffer setPosition(long position)
{
buffer.position(Ints.checkedCast(position));
return this;
}
@Override
public BaseBuffer setAndCheckPosition(long position)
{
checkInvariants(start, position, end, buffer.capacity());
return setPosition(position);
}
@Override
public BaseBuffer setStartPositionEnd(long start, long position, long end)
{
this.start = Ints.checkedCast(start);
this.end = Ints.checkedCast(end);
buffer.position(Ints.checkedCast(position));
buffer.limit(this.end);
return this;
}
@Override
public BaseBuffer setAndCheckStartPositionEnd(long start, long position, long end)
{
checkInvariants(start, position, end, buffer.capacity());
return setStartPositionEnd(start, position, end);
}
@Override
public boolean equalTo(long thisOffsetBytes, Object that, long thatOffsetBytes, long lengthBytes)
{
if (!(that instanceof SafeWritableBuffer)) {
return false;
}
return compareTo(thisOffsetBytes, lengthBytes, (SafeWritableBuffer) that, thatOffsetBytes, lengthBytes) == 0;
}
/**
* Adapted from {@link org.apache.datasketches.memory.internal.BaseBufferImpl#checkInvariants(long, long, long, long)}
*/
static void checkInvariants(final long start, final long pos, final long end, final long cap)
{
if ((start | pos | end | cap | (pos - start) | (end - pos) | (cap - end)) < 0L) {
throw new IllegalArgumentException(
"Violation of Invariants: "
+ "start: " + start
+ " <= pos: " + pos
+ " <= end: " + end
+ " <= cap: " + cap
+ "; (pos - start): " + (pos - start)
+ ", (end - pos): " + (end - pos)
+ ", (cap - end): " + (cap - end)
);
}
}
}

View File

@ -0,0 +1,417 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.data;
import com.google.common.primitives.Ints;
import org.apache.datasketches.memory.Buffer;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.Utf8CodingException;
import org.apache.datasketches.memory.WritableBuffer;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.java.util.common.StringUtils;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.WritableByteChannel;
/**
* Safety first! Don't trust something whose contents you locations to read and write stuff to, but need a
* {@link Memory} or {@link WritableMemory}? use this!
* <p>
* Delegates everything to an underlying {@link ByteBuffer} so all read and write operations will have bounds checks
* built in rather than using 'unsafe'.
*/
public class SafeWritableMemory extends SafeWritableBase implements WritableMemory
{
public static SafeWritableMemory wrap(byte[] bytes)
{
return wrap(ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()), 0, bytes.length);
}
public static SafeWritableMemory wrap(ByteBuffer buffer)
{
return wrap(buffer.duplicate().order(buffer.order()), 0, buffer.capacity());
}
public static SafeWritableMemory wrap(ByteBuffer buffer, ByteOrder byteOrder)
{
return wrap(buffer.duplicate().order(byteOrder), 0, buffer.capacity());
}
public static SafeWritableMemory wrap(ByteBuffer buffer, int offset, int size)
{
final ByteBuffer dupe = buffer.duplicate().order(buffer.order());
dupe.position(offset);
dupe.limit(offset + size);
return new SafeWritableMemory(dupe.slice().order(buffer.order()));
}
public SafeWritableMemory(ByteBuffer buffer)
{
super(buffer);
}
@Override
public Memory region(long offsetBytes, long capacityBytes, ByteOrder byteOrder)
{
return writableRegion(offsetBytes, capacityBytes, byteOrder);
}
@Override
public Buffer asBuffer(ByteOrder byteOrder)
{
return asWritableBuffer(byteOrder);
}
@Override
public void getBooleanArray(long offsetBytes, boolean[] dstArray, int dstOffsetBooleans, int lengthBooleans)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthBooleans; j++) {
dstArray[dstOffsetBooleans + j] = buffer.get(offset + j) != 0;
}
}
@Override
public void getByteArray(long offsetBytes, byte[] dstArray, int dstOffsetBytes, int lengthBytes)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthBytes; j++) {
dstArray[dstOffsetBytes + j] = buffer.get(offset + j);
}
}
@Override
public void getCharArray(long offsetBytes, char[] dstArray, int dstOffsetChars, int lengthChars)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthChars; j++) {
dstArray[dstOffsetChars + j] = buffer.getChar(offset + (j * Character.BYTES));
}
}
@Override
public int getCharsFromUtf8(long offsetBytes, int utf8LengthBytes, Appendable dst)
throws IOException, Utf8CodingException
{
ByteBuffer dupe = buffer.asReadOnlyBuffer().order(buffer.order());
dupe.position(Ints.checkedCast(offsetBytes));
String s = StringUtils.fromUtf8(dupe, utf8LengthBytes);
dst.append(s);
return s.length();
}
@Override
public int getCharsFromUtf8(long offsetBytes, int utf8LengthBytes, StringBuilder dst) throws Utf8CodingException
{
ByteBuffer dupe = buffer.asReadOnlyBuffer().order(buffer.order());
dupe.position(Ints.checkedCast(offsetBytes));
String s = StringUtils.fromUtf8(dupe, utf8LengthBytes);
dst.append(s);
return s.length();
}
@Override
public void getDoubleArray(long offsetBytes, double[] dstArray, int dstOffsetDoubles, int lengthDoubles)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthDoubles; j++) {
dstArray[dstOffsetDoubles + j] = buffer.getDouble(offset + (j * Double.BYTES));
}
}
@Override
public void getFloatArray(long offsetBytes, float[] dstArray, int dstOffsetFloats, int lengthFloats)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthFloats; j++) {
dstArray[dstOffsetFloats + j] = buffer.getFloat(offset + (j * Float.BYTES));
}
}
@Override
public void getIntArray(long offsetBytes, int[] dstArray, int dstOffsetInts, int lengthInts)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthInts; j++) {
dstArray[dstOffsetInts + j] = buffer.getInt(offset + (j * Integer.BYTES));
}
}
@Override
public void getLongArray(long offsetBytes, long[] dstArray, int dstOffsetLongs, int lengthLongs)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthLongs; j++) {
dstArray[dstOffsetLongs + j] = buffer.getLong(offset + (j * Long.BYTES));
}
}
@Override
public void getShortArray(long offsetBytes, short[] dstArray, int dstOffsetShorts, int lengthShorts)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int j = 0; j < lengthShorts; j++) {
dstArray[dstOffsetShorts + j] = buffer.getShort(offset + (j * Short.BYTES));
}
}
@Override
public int compareTo(
long thisOffsetBytes,
long thisLengthBytes,
Memory that,
long thatOffsetBytes,
long thatLengthBytes
)
{
final int thisLength = Ints.checkedCast(thisLengthBytes);
final int thatLength = Ints.checkedCast(thatLengthBytes);
final int commonLength = Math.min(thisLength, thatLength);
for (int i = 0; i < commonLength; i++) {
final int cmp = Byte.compare(getByte(thisOffsetBytes + i), that.getByte(thatOffsetBytes + i));
if (cmp != 0) {
return cmp;
}
}
return Integer.compare(thisLength, thatLength);
}
@Override
public void copyTo(long srcOffsetBytes, WritableMemory destination, long dstOffsetBytes, long lengthBytes)
{
int offset = Ints.checkedCast(srcOffsetBytes);
for (int i = 0; i < lengthBytes; i++) {
destination.putByte(dstOffsetBytes + i, buffer.get(offset + i));
}
}
@Override
public void writeTo(long offsetBytes, long lengthBytes, WritableByteChannel out) throws IOException
{
ByteBuffer dupe = buffer.duplicate();
dupe.position(Ints.checkedCast(offsetBytes));
dupe.limit(dupe.position() + Ints.checkedCast(lengthBytes));
ByteBuffer view = dupe.slice();
view.order(buffer.order());
out.write(view);
}
@Override
public boolean equalTo(long thisOffsetBytes, Object that, long thatOffsetBytes, long lengthBytes)
{
if (!(that instanceof SafeWritableMemory)) {
return false;
}
return compareTo(thisOffsetBytes, lengthBytes, (SafeWritableMemory) that, thatOffsetBytes, lengthBytes) == 0;
}
@Override
public WritableMemory writableRegion(long offsetBytes, long capacityBytes, ByteOrder byteOrder)
{
final ByteBuffer dupe = buffer.duplicate().order(buffer.order());
final int sizeBytes = Ints.checkedCast(capacityBytes);
dupe.position(Ints.checkedCast(offsetBytes));
dupe.limit(dupe.position() + sizeBytes);
final ByteBuffer view = dupe.slice();
view.order(byteOrder);
return new SafeWritableMemory(view);
}
@Override
public WritableBuffer asWritableBuffer(ByteOrder byteOrder)
{
return new SafeWritableBuffer(buffer.duplicate().order(byteOrder));
}
@Override
public void putBooleanArray(long offsetBytes, boolean[] srcArray, int srcOffsetBooleans, int lengthBooleans)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthBooleans; i++) {
buffer.put(offset + i, (byte) (srcArray[i + srcOffsetBooleans] ? 1 : 0));
}
}
@Override
public void putByteArray(long offsetBytes, byte[] srcArray, int srcOffsetBytes, int lengthBytes)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthBytes; i++) {
buffer.put(offset + i, srcArray[srcOffsetBytes + i]);
}
}
@Override
public void putCharArray(long offsetBytes, char[] srcArray, int srcOffsetChars, int lengthChars)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthChars; i++) {
buffer.putChar(offset + (i * Character.BYTES), srcArray[srcOffsetChars + i]);
}
}
@Override
public long putCharsToUtf8(long offsetBytes, CharSequence src)
{
final byte[] bytes = StringUtils.toUtf8(src.toString());
putByteArray(offsetBytes, bytes, 0, bytes.length);
return bytes.length;
}
@Override
public void putDoubleArray(long offsetBytes, double[] srcArray, int srcOffsetDoubles, int lengthDoubles)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthDoubles; i++) {
buffer.putDouble(offset + (i * Double.BYTES), srcArray[srcOffsetDoubles + i]);
}
}
@Override
public void putFloatArray(long offsetBytes, float[] srcArray, int srcOffsetFloats, int lengthFloats)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthFloats; i++) {
buffer.putFloat(offset + (i * Float.BYTES), srcArray[srcOffsetFloats + i]);
}
}
@Override
public void putIntArray(long offsetBytes, int[] srcArray, int srcOffsetInts, int lengthInts)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthInts; i++) {
buffer.putInt(offset + (i * Integer.BYTES), srcArray[srcOffsetInts + i]);
}
}
@Override
public void putLongArray(long offsetBytes, long[] srcArray, int srcOffsetLongs, int lengthLongs)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthLongs; i++) {
buffer.putLong(offset + (i * Long.BYTES), srcArray[srcOffsetLongs + i]);
}
}
@Override
public void putShortArray(long offsetBytes, short[] srcArray, int srcOffsetShorts, int lengthShorts)
{
final int offset = Ints.checkedCast(offsetBytes);
for (int i = 0; i < lengthShorts; i++) {
buffer.putShort(offset + (i * Short.BYTES), srcArray[srcOffsetShorts + i]);
}
}
@Override
public long getAndAddLong(long offsetBytes, long delta)
{
final int offset = Ints.checkedCast(offsetBytes);
final long currentValue;
synchronized (buffer) {
currentValue = buffer.getLong(offset);
buffer.putLong(offset, currentValue + delta);
}
return currentValue;
}
@Override
public boolean compareAndSwapLong(long offsetBytes, long expect, long update)
{
final int offset = Ints.checkedCast(offsetBytes);
synchronized (buffer) {
final long actual = buffer.getLong(offset);
if (expect == actual) {
buffer.putLong(offset, update);
return true;
}
}
return false;
}
@Override
public long getAndSetLong(long offsetBytes, long newValue)
{
int offset = Ints.checkedCast(offsetBytes);
synchronized (buffer) {
long l = buffer.getLong(offset);
buffer.putLong(offset, newValue);
return l;
}
}
@Override
public Object getArray()
{
return null;
}
@Override
public void clear()
{
fill((byte) 0);
}
@Override
public void clear(long offsetBytes, long lengthBytes)
{
fill(offsetBytes, lengthBytes, (byte) 0);
}
@Override
public void clearBits(long offsetBytes, byte bitMask)
{
final int offset = Ints.checkedCast(offsetBytes);
int value = buffer.get(offset) & 0XFF;
value &= ~bitMask;
buffer.put(offset, (byte) value);
}
@Override
public void fill(byte value)
{
for (int i = 0; i < buffer.capacity(); i++) {
buffer.put(i, value);
}
}
@Override
public void fill(long offsetBytes, long lengthBytes, byte value)
{
int offset = Ints.checkedCast(offsetBytes);
int length = Ints.checkedCast(lengthBytes);
for (int i = 0; i < length; i++) {
buffer.put(offset + i, value);
}
}
@Override
public void setBits(long offsetBytes, byte bitMask)
{
final int offset = Ints.checkedCast(offsetBytes);
buffer.put(offset, (byte) (buffer.get(offset) | bitMask));
}
}

View File

@ -0,0 +1,224 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.data;
import org.apache.datasketches.memory.Buffer;
import org.apache.datasketches.memory.WritableBuffer;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class SafeWritableBufferTest
{
private static final int CAPACITY = 1024;
@Test
public void testPutAndGet()
{
WritableBuffer b1 = getBuffer();
Assert.assertEquals(0, b1.getPosition());
b1.putByte((byte) 0x01);
Assert.assertEquals(1, b1.getPosition());
b1.putBoolean(true);
Assert.assertEquals(2, b1.getPosition());
b1.putBoolean(false);
Assert.assertEquals(3, b1.getPosition());
b1.putChar('c');
Assert.assertEquals(5, b1.getPosition());
b1.putDouble(1.1);
Assert.assertEquals(13, b1.getPosition());
b1.putFloat(1.1f);
Assert.assertEquals(17, b1.getPosition());
b1.putInt(100);
Assert.assertEquals(21, b1.getPosition());
b1.putLong(1000L);
Assert.assertEquals(29, b1.getPosition());
b1.putShort((short) 15);
Assert.assertEquals(31, b1.getPosition());
b1.resetPosition();
Assert.assertEquals(0x01, b1.getByte());
Assert.assertTrue(b1.getBoolean());
Assert.assertFalse(b1.getBoolean());
Assert.assertEquals('c', b1.getChar());
Assert.assertEquals(1.1, b1.getDouble(), 0.0);
Assert.assertEquals(1.1f, b1.getFloat(), 0.0);
Assert.assertEquals(100, b1.getInt());
Assert.assertEquals(1000L, b1.getLong());
Assert.assertEquals(15, b1.getShort());
}
@Test
public void testPutAndGetArrays()
{
WritableBuffer buffer = getBuffer();
final byte[] b1 = new byte[]{0x01, 0x02, 0x08, 0x08};
final byte[] b2 = new byte[b1.length];
final boolean[] bool1 = new boolean[]{true, false, false, true};
final boolean[] bool2 = new boolean[bool1.length];
final char[] chars1 = new char[]{'a', 'b', 'c', 'd'};
final char[] chars2 = new char[chars1.length];
final double[] double1 = new double[]{1.1, -2.2, 3.3, 4.4};
final double[] double2 = new double[double1.length];
final float[] float1 = new float[]{1.1f, 2.2f, -3.3f, 4.4f};
final float[] float2 = new float[float1.length];
final int[] ints1 = new int[]{1, 2, -3, 4};
final int[] ints2 = new int[ints1.length];
final long[] longs1 = new long[]{1L, -2L, 3L, -14L};
final long[] longs2 = new long[ints1.length];
final short[] shorts1 = new short[]{1, -2, 3, -14};
final short[] shorts2 = new short[ints1.length];
buffer.putByteArray(b1, 0, 2);
buffer.putByteArray(b1, 2, b1.length - 2);
buffer.putBooleanArray(bool1, 0, bool1.length);
buffer.putCharArray(chars1, 0, chars1.length);
buffer.putDoubleArray(double1, 0, double1.length);
buffer.putFloatArray(float1, 0, float1.length);
buffer.putIntArray(ints1, 0, ints1.length);
buffer.putLongArray(longs1, 0, longs1.length);
buffer.putShortArray(shorts1, 0, shorts1.length);
long pos = buffer.getPosition();
buffer.resetPosition();
buffer.getByteArray(b2, 0, b1.length);
buffer.getBooleanArray(bool2, 0, bool1.length);
buffer.getCharArray(chars2, 0, chars1.length);
buffer.getDoubleArray(double2, 0, double1.length);
buffer.getFloatArray(float2, 0, float1.length);
buffer.getIntArray(ints2, 0, ints1.length);
buffer.getLongArray(longs2, 0, longs1.length);
buffer.getShortArray(shorts2, 0, shorts1.length);
Assert.assertArrayEquals(b1, b2);
Assert.assertArrayEquals(bool1, bool2);
Assert.assertArrayEquals(chars1, chars2);
for (int i = 0; i < double1.length; i++) {
Assert.assertEquals(double1[i], double2[i], 0.0);
}
for (int i = 0; i < float1.length; i++) {
Assert.assertEquals(float1[i], float2[i], 0.0);
}
Assert.assertArrayEquals(ints1, ints2);
Assert.assertArrayEquals(longs1, longs2);
Assert.assertArrayEquals(shorts1, shorts2);
Assert.assertEquals(pos, buffer.getPosition());
}
@Test
public void testStartEndRegionAndDuplicate()
{
WritableBuffer buffer = getBuffer();
Assert.assertEquals(0, buffer.getPosition());
Assert.assertEquals(0, buffer.getStart());
Assert.assertEquals(CAPACITY, buffer.getEnd());
Assert.assertEquals(CAPACITY, buffer.getRemaining());
Assert.assertEquals(CAPACITY, buffer.getCapacity());
Assert.assertTrue(buffer.hasRemaining());
buffer.fill((byte) 0x07);
buffer.setAndCheckStartPositionEnd(10L, 15L, 100L);
Assert.assertEquals(15L, buffer.getPosition());
Assert.assertEquals(10L, buffer.getStart());
Assert.assertEquals(100L, buffer.getEnd());
Assert.assertEquals(85L, buffer.getRemaining());
Assert.assertEquals(CAPACITY, buffer.getCapacity());
buffer.fill((byte) 0x70);
buffer.resetPosition();
Assert.assertEquals(10L, buffer.getPosition());
for (int i = 0; i < 90; i++) {
if (i < 5) {
Assert.assertEquals(0x07, buffer.getByte());
} else {
Assert.assertEquals(0x70, buffer.getByte());
}
}
buffer.setAndCheckPosition(50);
Buffer duplicate = buffer.duplicate();
Assert.assertEquals(buffer.getStart(), duplicate.getStart());
Assert.assertEquals(buffer.getPosition(), duplicate.getPosition());
Assert.assertEquals(buffer.getEnd(), duplicate.getEnd());
Assert.assertEquals(buffer.getRemaining(), duplicate.getRemaining());
Assert.assertEquals(buffer.getCapacity(), duplicate.getCapacity());
duplicate.resetPosition();
for (int i = 0; i < 90; i++) {
if (i < 5) {
Assert.assertEquals(0x07, duplicate.getByte());
} else {
Assert.assertEquals(0x70, duplicate.getByte());
}
}
Buffer region = buffer.region(5L, 105L, buffer.getTypeByteOrder());
Assert.assertEquals(0, region.getStart());
Assert.assertEquals(0, region.getPosition());
Assert.assertEquals(105L, region.getEnd());
Assert.assertEquals(105L, region.getRemaining());
Assert.assertEquals(105L, region.getCapacity());
for (int i = 0; i < 105; i++) {
if (i < 10) {
Assert.assertEquals(0x07, region.getByte());
} else if (i < 95) {
Assert.assertEquals(0x70, region.getByte());
} else {
Assert.assertEquals(0x07, region.getByte());
}
}
}
@Test
public void testFill()
{
WritableBuffer buffer = getBuffer();
WritableBuffer anotherBuffer = getBuffer();
buffer.fill((byte) 0x0F);
anotherBuffer.fill((byte) 0x0F);
Assert.assertTrue(buffer.equalTo(0L, anotherBuffer, 0L, CAPACITY));
anotherBuffer.setPosition(100);
anotherBuffer.clear();
Assert.assertFalse(buffer.equalTo(0L, anotherBuffer, 0L, CAPACITY));
Assert.assertTrue(buffer.equalTo(0L, anotherBuffer, 0L, 100L));
}
private WritableBuffer getBuffer()
{
return getBuffer(CAPACITY);
}
private WritableBuffer getBuffer(int capacity)
{
final ByteBuffer aBuffer = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN);
SafeWritableBuffer memory = new SafeWritableBuffer(aBuffer);
return memory;
}
}

View File

@ -0,0 +1,359 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.data;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.memory.internal.UnsafeUtil;
import org.junit.Assert;
import org.junit.Test;
import java.io.CharArrayWriter;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class SafeWritableMemoryTest
{
private static final int CAPACITY = 1024;
@Test
public void testPutAndGet()
{
final WritableMemory memory = getMemory();
memory.putByte(3L, (byte) 0x01);
Assert.assertEquals(memory.getByte(3L), 0x01);
memory.putBoolean(1L, true);
Assert.assertTrue(memory.getBoolean(1L));
memory.putBoolean(1L, false);
Assert.assertFalse(memory.getBoolean(1L));
memory.putChar(10L, 'c');
Assert.assertEquals('c', memory.getChar(10L));
memory.putDouble(14L, 3.3);
Assert.assertEquals(3.3, memory.getDouble(14L), 0.0);
memory.putFloat(27L, 3.3f);
Assert.assertEquals(3.3f, memory.getFloat(27L), 0.0);
memory.putInt(11L, 1234);
Assert.assertEquals(1234, memory.getInt(11L));
memory.putLong(500L, 500L);
Assert.assertEquals(500L, memory.getLong(500L));
memory.putShort(11L, (short) 15);
Assert.assertEquals(15, memory.getShort(11L));
long l = memory.getAndSetLong(900L, 10L);
Assert.assertEquals(0L, l);
l = memory.getAndSetLong(900L, 100L);
Assert.assertEquals(10L, l);
l = memory.getAndAddLong(900L, 10L);
Assert.assertEquals(100L, l);
Assert.assertEquals(110L, memory.getLong(900L));
Assert.assertTrue(memory.compareAndSwapLong(900L, 110L, 120L));
Assert.assertFalse(memory.compareAndSwapLong(900L, 110L, 120L));
Assert.assertEquals(120L, memory.getLong(900L));
}
@Test
public void testPutAndGetArrays()
{
final WritableMemory memory = getMemory();
final byte[] b1 = new byte[]{0x01, 0x02, 0x08, 0x08};
final byte[] b2 = new byte[b1.length];
memory.putByteArray(12L, b1, 0, 3);
memory.putByteArray(15L, b1, 3, 1);
memory.getByteArray(12L, b2, 0, 3);
memory.getByteArray(15L, b2, 3, 1);
Assert.assertArrayEquals(b1, b2);
final boolean[] bool1 = new boolean[]{true, false, false, true};
final boolean[] bool2 = new boolean[bool1.length];
memory.putBooleanArray(100L, bool1, 0, 2);
memory.putBooleanArray(102L, bool1, 2, 2);
memory.getBooleanArray(100L, bool2, 0, 2);
memory.getBooleanArray(102L, bool2, 2, 2);
Assert.assertArrayEquals(bool1, bool2);
final char[] chars1 = new char[]{'a', 'b', 'c', 'd'};
final char[] chars2 = new char[chars1.length];
memory.putCharArray(10L, chars1, 0, 4);
memory.getCharArray(10L, chars2, 0, chars1.length);
Assert.assertArrayEquals(chars1, chars2);
final double[] double1 = new double[]{1.1, -2.2, 3.3, 4.4};
final double[] double2 = new double[double1.length];
memory.putDoubleArray(100L, double1, 0, 1);
memory.putDoubleArray(100L + Double.BYTES, double1, 1, 3);
memory.getDoubleArray(100L, double2, 0, 2);
memory.getDoubleArray(100L + (2 * Double.BYTES), double2, 2, 2);
for (int i = 0; i < double1.length; i++) {
Assert.assertEquals(double1[i], double2[i], 0.0);
}
final float[] float1 = new float[]{1.1f, 2.2f, -3.3f, 4.4f};
final float[] float2 = new float[float1.length];
memory.putFloatArray(100L, float1, 0, 1);
memory.putFloatArray(100L + Float.BYTES, float1, 1, 3);
memory.getFloatArray(100L, float2, 0, 2);
memory.getFloatArray(100L + (2 * Float.BYTES), float2, 2, 2);
for (int i = 0; i < float1.length; i++) {
Assert.assertEquals(float1[i], float2[i], 0.0);
}
final int[] ints1 = new int[]{1, 2, -3, 4};
final int[] ints2 = new int[ints1.length];
memory.putIntArray(100L, ints1, 0, 1);
memory.putIntArray(100L + Integer.BYTES, ints1, 1, 3);
memory.getIntArray(100L, ints2, 0, 2);
memory.getIntArray(100L + (2 * Integer.BYTES), ints2, 2, 2);
Assert.assertArrayEquals(ints1, ints2);
final long[] longs1 = new long[]{1L, -2L, 3L, -14L};
final long[] longs2 = new long[ints1.length];
memory.putLongArray(100L, longs1, 0, 1);
memory.putLongArray(100L + Long.BYTES, longs1, 1, 3);
memory.getLongArray(100L, longs2, 0, 2);
memory.getLongArray(100L + (2 * Long.BYTES), longs2, 2, 2);
Assert.assertArrayEquals(longs1, longs2);
final short[] shorts1 = new short[]{1, -2, 3, -14};
final short[] shorts2 = new short[ints1.length];
memory.putShortArray(100L, shorts1, 0, 1);
memory.putShortArray(100L + Short.BYTES, shorts1, 1, 3);
memory.getShortArray(100L, shorts2, 0, 2);
memory.getShortArray(100L + (2 * Short.BYTES), shorts2, 2, 2);
Assert.assertArrayEquals(shorts1, shorts2);
}
@Test
public void testFill()
{
final byte theByte = 0x01;
final byte anotherByte = 0x02;
final WritableMemory memory = getMemory();
final int halfWay = (int) (memory.getCapacity() / 2);
memory.fill(theByte);
for (int i = 0; i < memory.getCapacity(); i++) {
Assert.assertEquals(theByte, memory.getByte(i));
}
memory.fill(halfWay, memory.getCapacity() - halfWay, anotherByte);
for (int i = 0; i < memory.getCapacity(); i++) {
if (i < halfWay) {
Assert.assertEquals(theByte, memory.getByte(i));
} else {
Assert.assertEquals(anotherByte, memory.getByte(i));
}
}
memory.clear(halfWay, memory.getCapacity() - halfWay);
for (int i = 0; i < memory.getCapacity(); i++) {
if (i < halfWay) {
Assert.assertEquals(theByte, memory.getByte(i));
} else {
Assert.assertEquals(0, memory.getByte(i));
}
}
memory.setBits(halfWay - 1, anotherByte);
Assert.assertEquals(0x03, memory.getByte(halfWay - 1));
memory.clearBits(halfWay - 1, theByte);
Assert.assertEquals(anotherByte, memory.getByte(halfWay - 1));
memory.clear();
for (int i = 0; i < memory.getCapacity(); i++) {
Assert.assertEquals(0, memory.getByte(i));
}
}
@Test
public void testStringStuff() throws IOException
{
WritableMemory memory = getMemory();
String s1 = "hello ";
memory.putCharsToUtf8(10L, s1);
StringBuilder builder = new StringBuilder();
memory.getCharsFromUtf8(10L, s1.length(), builder);
Assert.assertEquals(s1, builder.toString());
CharArrayWriter someAppendable = new CharArrayWriter();
memory.getCharsFromUtf8(10L, s1.length(), someAppendable);
Assert.assertEquals(s1, someAppendable.toString());
}
@Test
public void testRegion()
{
WritableMemory memory = getMemory();
Assert.assertEquals(CAPACITY, memory.getCapacity());
Assert.assertEquals(0, memory.getCumulativeOffset());
Assert.assertEquals(10L, memory.getCumulativeOffset(10L));
Assert.assertThrows(
IllegalArgumentException.class,
() -> memory.checkValidAndBounds(CAPACITY - 10, 11L)
);
final byte[] someBytes = new byte[]{0x01, 0x02, 0x03, 0x04};
memory.putByteArray(10L, someBytes, 0, someBytes.length);
Memory region = memory.region(10L, someBytes.length);
Assert.assertEquals(someBytes.length, region.getCapacity());
Assert.assertEquals(0, region.getCumulativeOffset());
Assert.assertEquals(2L, region.getCumulativeOffset(2L));
Assert.assertThrows(
IllegalArgumentException.class,
() -> region.checkValidAndBounds(2L, 4L)
);
final byte[] andBack = new byte[someBytes.length];
region.getByteArray(0L, andBack, 0, someBytes.length);
Assert.assertArrayEquals(someBytes, andBack);
Memory differentOrderRegion = memory.region(10L, someBytes.length, ByteOrder.BIG_ENDIAN);
// different order
Assert.assertFalse(region.isByteOrderCompatible(differentOrderRegion.getTypeByteOrder()));
// contents are equal tho
Assert.assertTrue(region.equalTo(0L, differentOrderRegion, 0L, someBytes.length));
}
@Test
public void testCompareAndEquals()
{
WritableMemory memory = getMemory();
final byte[] someBytes = new byte[]{0x01, 0x02, 0x03, 0x04};
final byte[] shorterSameBytes = new byte[]{0x01, 0x02, 0x03};
final byte[] differentBytes = new byte[]{0x02, 0x02, 0x03, 0x04};
memory.putByteArray(10L, someBytes, 0, someBytes.length);
memory.putByteArray(400L, someBytes, 0, someBytes.length);
memory.putByteArray(200L, shorterSameBytes, 0, shorterSameBytes.length);
memory.putByteArray(500L, differentBytes, 0, differentBytes.length);
Assert.assertEquals(0, memory.compareTo(10L, someBytes.length, memory, 400L, someBytes.length));
Assert.assertEquals(4, memory.compareTo(10L, someBytes.length, memory, 200L, someBytes.length));
Assert.assertEquals(-1, memory.compareTo(10L, someBytes.length, memory, 500L, differentBytes.length));
WritableMemory memory2 = getMemory();
memory2.putByteArray(0L, someBytes, 0, someBytes.length);
Assert.assertEquals(0, memory.compareTo(10L, someBytes.length, memory2, 0L, someBytes.length));
Assert.assertTrue(memory.equalTo(10L, memory2, 0L, someBytes.length));
WritableMemory memory3 = getMemory();
memory2.copyTo(0L, memory3, 0L, CAPACITY);
Assert.assertTrue(memory2.equalTo(0L, memory3, 0L, CAPACITY));
}
@Test
public void testHash()
{
WritableMemory memory = getMemory();
final long[] someLongs = new long[]{1L, 10L, 100L, 1000L, 10000L};
final int[] someInts = new int[]{1, 2, 3};
final byte[] someBytes = new byte[]{0x01, 0x02, 0x03};
final int longsLength = Long.BYTES * someLongs.length;
final int someIntsLength = Integer.BYTES * someInts.length;
final int totalLength = longsLength + someIntsLength + someBytes.length;
memory.putLongArray(2L, someLongs, 0, someLongs.length);
memory.putIntArray(2L + longsLength, someInts, 0, someInts.length);
memory.putByteArray(2L + longsLength + someIntsLength, someBytes, 0, someBytes.length);
Memory memory2 = Memory.wrap(memory.getByteBuffer(), ByteOrder.LITTLE_ENDIAN);
Assert.assertEquals(
memory2.xxHash64(2L, totalLength, 0),
memory.xxHash64(2L, totalLength, 0)
);
Assert.assertEquals(
memory2.xxHash64(2L, 0),
memory.xxHash64(2L, 0)
);
}
@Test
public void testToHexString()
{
final byte[] bytes = new byte[]{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07};
final WritableMemory memory = getMemory(bytes.length);
memory.putByteArray(0L, bytes, 0, bytes.length);
final long hcode = memory.hashCode() & 0XFFFFFFFFL;
final long bufferhcode = memory.getByteBuffer().hashCode() & 0XFFFFFFFFL;
final long reqhcode = memory.getMemoryRequestServer().hashCode() & 0XFFFFFFFFL;
Assert.assertEquals(
"### SafeWritableMemory SUMMARY ###\n"
+ "Header Comment : test memory dump\n"
+ "Call Parameters : .toHexString(..., 0, 8), hashCode: " + hcode + "\n"
+ "UnsafeObj, hashCode : null\n"
+ "UnsafeObjHeader : 0\n"
+ "ByteBuf, hashCode : HeapByteBuffer, " + bufferhcode + "\n"
+ "RegionOffset : 0\n"
+ "Capacity : 8\n"
+ "CumBaseOffset : 0\n"
+ "MemReq, hashCode : HeapByteBufferMemoryRequestServer, " + reqhcode + "\n"
+ "Valid : true\n"
+ "Read Only : false\n"
+ "Type Byte Order : LITTLE_ENDIAN\n"
+ "Native Byte Order : LITTLE_ENDIAN\n"
+ "JDK Runtime Version : " + UnsafeUtil.JDK + "\n"
+ "Data, littleEndian : 0 1 2 3 4 5 6 7\n"
+ " 0: 00 01 02 03 04 05 06 07 \n",
memory.toHexString("test memory dump", 0, bytes.length)
);
}
@Test
public void testMisc()
{
WritableMemory memory = getMemory(10);
WritableMemory memory2 = memory.getMemoryRequestServer().request(memory, 20);
Assert.assertEquals(20, memory2.getCapacity());
Assert.assertFalse(memory2.hasArray());
Assert.assertFalse(memory2.isReadOnly());
Assert.assertFalse(memory2.isDirect());
Assert.assertTrue(memory2.isValid());
Assert.assertTrue(memory2.hasByteBuffer());
Assert.assertFalse(memory2.isSameResource(memory));
Assert.assertTrue(memory2.isSameResource(memory2));
// does nothing
memory.getMemoryRequestServer().requestClose(memory, memory2);
}
private WritableMemory getMemory()
{
return getMemory(CAPACITY);
}
private WritableMemory getMemory(int capacity)
{
final ByteBuffer aBuffer = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN);
return SafeWritableMemory.wrap(aBuffer);
}
}